mirror of
https://github.com/pgvector/pgvector.git
synced 2026-06-06 05:51:21 +08:00
582 lines
15 KiB
C
582 lines
15 KiB
C
#include "postgres.h"
|
|
|
|
#include <float.h>
|
|
#include <limits.h>
|
|
#include <math.h>
|
|
|
|
#include "access/genam.h"
|
|
#include "fmgr.h"
|
|
#include "ivfflat.h"
|
|
#include "miscadmin.h"
|
|
#include "utils/memutils.h"
|
|
#include "utils/relcache.h"
|
|
|
|
#if PG_VERSION_NUM >= 160000
|
|
#include "varatt.h"
|
|
#endif
|
|
|
|
/*
|
|
* Initialize with kmeans++
|
|
*
|
|
* https://theory.stanford.edu/~sergei/papers/kMeansPP-soda.pdf
|
|
*/
|
|
static void
|
|
InitCenters(Relation index, VectorArray samples, VectorArray centers, float *lowerBound)
|
|
{
|
|
FmgrInfo *procinfo;
|
|
Oid collation;
|
|
int64 j;
|
|
float *weight = palloc(samples->length * sizeof(float));
|
|
int numCenters = centers->maxlen;
|
|
int numSamples = samples->length;
|
|
|
|
procinfo = index_getprocinfo(index, 1, IVFFLAT_KMEANS_DISTANCE_PROC);
|
|
collation = index->rd_indcollation[0];
|
|
|
|
/* Choose an initial center uniformly at random */
|
|
VectorArraySet(centers, 0, VectorArrayGet(samples, RandomInt() % samples->length));
|
|
centers->length++;
|
|
|
|
for (j = 0; j < numSamples; j++)
|
|
weight[j] = FLT_MAX;
|
|
|
|
for (int i = 0; i < numCenters; i++)
|
|
{
|
|
double sum;
|
|
double choice;
|
|
|
|
CHECK_FOR_INTERRUPTS();
|
|
|
|
sum = 0.0;
|
|
|
|
for (j = 0; j < numSamples; j++)
|
|
{
|
|
Datum vec = PointerGetDatum(VectorArrayGet(samples, j));
|
|
double distance;
|
|
|
|
/* Only need to compute distance for new center */
|
|
/* TODO Use triangle inequality to reduce distance calculations */
|
|
distance = DatumGetFloat8(FunctionCall2Coll(procinfo, collation, vec, PointerGetDatum(VectorArrayGet(centers, i))));
|
|
|
|
/* Set lower bound */
|
|
lowerBound[j * numCenters + i] = distance;
|
|
|
|
/* Use distance squared for weighted probability distribution */
|
|
distance *= distance;
|
|
|
|
if (distance < weight[j])
|
|
weight[j] = distance;
|
|
|
|
sum += weight[j];
|
|
}
|
|
|
|
/* Only compute lower bound on last iteration */
|
|
if (i + 1 == numCenters)
|
|
break;
|
|
|
|
/* Choose new center using weighted probability distribution. */
|
|
choice = sum * RandomDouble();
|
|
for (j = 0; j < numSamples - 1; j++)
|
|
{
|
|
choice -= weight[j];
|
|
if (choice <= 0)
|
|
break;
|
|
}
|
|
|
|
VectorArraySet(centers, i + 1, VectorArrayGet(samples, j));
|
|
centers->length++;
|
|
}
|
|
|
|
pfree(weight);
|
|
}
|
|
|
|
/*
|
|
* Norm centers
|
|
*/
|
|
static void
|
|
NormCenters(const IvfflatTypeInfo * typeInfo, Oid collation, VectorArray centers)
|
|
{
|
|
MemoryContext normCtx = AllocSetContextCreate(CurrentMemoryContext,
|
|
"Ivfflat norm temporary context",
|
|
ALLOCSET_DEFAULT_SIZES);
|
|
MemoryContext oldCtx = MemoryContextSwitchTo(normCtx);
|
|
|
|
for (int j = 0; j < centers->length; j++)
|
|
{
|
|
Datum center = PointerGetDatum(VectorArrayGet(centers, j));
|
|
Datum newCenter = IvfflatNormValue(typeInfo, collation, center);
|
|
Size size = VARSIZE_ANY(DatumGetPointer(newCenter));
|
|
|
|
if (size > centers->itemsize)
|
|
elog(ERROR, "safety check failed");
|
|
|
|
memcpy(DatumGetPointer(center), DatumGetPointer(newCenter), size);
|
|
MemoryContextReset(normCtx);
|
|
}
|
|
|
|
MemoryContextSwitchTo(oldCtx);
|
|
MemoryContextDelete(normCtx);
|
|
}
|
|
|
|
/*
|
|
* Quick approach if we have no data
|
|
*/
|
|
static void
|
|
RandomCenters(Relation index, VectorArray centers, const IvfflatTypeInfo * typeInfo)
|
|
{
|
|
int dimensions = centers->dim;
|
|
FmgrInfo *normprocinfo = IvfflatOptionalProcInfo(index, IVFFLAT_KMEANS_NORM_PROC);
|
|
Oid collation = index->rd_indcollation[0];
|
|
float *x = (float *) palloc(sizeof(float) * dimensions);
|
|
|
|
/* Fill with random data */
|
|
while (centers->length < centers->maxlen)
|
|
{
|
|
Pointer center = VectorArrayGet(centers, centers->length);
|
|
|
|
for (int i = 0; i < dimensions; i++)
|
|
x[i] = (float) RandomDouble();
|
|
|
|
typeInfo->updateCenter(center, dimensions, x);
|
|
|
|
centers->length++;
|
|
}
|
|
|
|
if (normprocinfo != NULL)
|
|
NormCenters(typeInfo, collation, centers);
|
|
}
|
|
|
|
#ifdef IVFFLAT_MEMORY
/*
 * Show memory usage
 *
 * Debug helper (compiled only with IVFFLAT_MEMORY). It logs, at INFO level
 * and in whole megabytes (truncated), the memory allocated by context and
 * its children next to the caller's estimate.
 */
static void
ShowMemoryUsage(MemoryContext context, Size estimatedSize)
{
	Size		allocated = MemoryContextMemAllocated(context, true);

	elog(INFO, "total memory: %zu MB", allocated / (1024 * 1024));
	elog(INFO, "estimated memory: %zu MB", estimatedSize / (1024 * 1024));
}
#endif
|
|
|
|
/*
|
|
* Sum centers
|
|
*/
|
|
static void
|
|
SumCenters(VectorArray samples, float *agg, int *closestCenters, const IvfflatTypeInfo * typeInfo)
|
|
{
|
|
for (int j = 0; j < samples->length; j++)
|
|
{
|
|
float *x = agg + ((int64) closestCenters[j] * samples->dim);
|
|
|
|
typeInfo->sumCenter(VectorArrayGet(samples, j), x);
|
|
}
|
|
}
|
|
|
|
/*
|
|
* Update centers
|
|
*/
|
|
static void
|
|
UpdateCenters(float *agg, VectorArray centers, const IvfflatTypeInfo * typeInfo)
|
|
{
|
|
for (int j = 0; j < centers->length; j++)
|
|
{
|
|
float *x = agg + ((int64) j * centers->dim);
|
|
|
|
typeInfo->updateCenter(VectorArrayGet(centers, j), centers->dim, x);
|
|
}
|
|
}
|
|
|
|
/*
|
|
* Compute new centers
|
|
*/
|
|
static void
|
|
ComputeNewCenters(VectorArray samples, float *agg, VectorArray newCenters, int *centerCounts, int *closestCenters, FmgrInfo *normprocinfo, Oid collation, const IvfflatTypeInfo * typeInfo)
|
|
{
|
|
int dimensions = newCenters->dim;
|
|
int numCenters = newCenters->length;
|
|
int numSamples = samples->length;
|
|
|
|
/* Reset sum and count */
|
|
for (int j = 0; j < numCenters; j++)
|
|
{
|
|
float *x = agg + ((int64) j * dimensions);
|
|
|
|
for (int k = 0; k < dimensions; k++)
|
|
x[k] = 0.0;
|
|
|
|
centerCounts[j] = 0;
|
|
}
|
|
|
|
/* Increment sum of closest center */
|
|
SumCenters(samples, agg, closestCenters, typeInfo);
|
|
|
|
/* Increment count of closest center */
|
|
for (int j = 0; j < numSamples; j++)
|
|
centerCounts[closestCenters[j]] += 1;
|
|
|
|
/* Divide sum by count */
|
|
for (int j = 0; j < numCenters; j++)
|
|
{
|
|
float *x = agg + ((int64) j * dimensions);
|
|
|
|
if (centerCounts[j] > 0)
|
|
{
|
|
/* Double avoids overflow, but requires more memory */
|
|
/* TODO Update bounds */
|
|
for (int k = 0; k < dimensions; k++)
|
|
{
|
|
if (isinf(x[k]))
|
|
x[k] = x[k] > 0 ? FLT_MAX : -FLT_MAX;
|
|
}
|
|
|
|
for (int k = 0; k < dimensions; k++)
|
|
x[k] /= centerCounts[j];
|
|
}
|
|
else
|
|
{
|
|
/* TODO Handle empty centers properly */
|
|
for (int k = 0; k < dimensions; k++)
|
|
x[k] = RandomDouble();
|
|
}
|
|
}
|
|
|
|
/* Set new centers */
|
|
UpdateCenters(agg, newCenters, typeInfo);
|
|
|
|
/* Normalize if needed */
|
|
if (normprocinfo != NULL)
|
|
NormCenters(typeInfo, collation, newCenters);
|
|
}
|
|
|
|
/*
|
|
* Use Elkan for performance. This requires distance function to satisfy triangle inequality.
|
|
*
|
|
* We use L2 distance for L2 (not L2 squared like index scan)
|
|
* and angular distance for inner product and cosine distance
|
|
*
|
|
* https://www.aaai.org/Papers/ICML/2003/ICML03-022.pdf
|
|
*/
|
|
static void
|
|
ElkanKmeans(Relation index, VectorArray samples, VectorArray centers, const IvfflatTypeInfo * typeInfo)
|
|
{
|
|
FmgrInfo *procinfo;
|
|
FmgrInfo *normprocinfo;
|
|
Oid collation;
|
|
int dimensions = centers->dim;
|
|
int numCenters = centers->maxlen;
|
|
int numSamples = samples->length;
|
|
VectorArray newCenters;
|
|
float *agg;
|
|
int *centerCounts;
|
|
int *closestCenters;
|
|
float *lowerBound;
|
|
float *upperBound;
|
|
float *s;
|
|
float *halfcdist;
|
|
float *newcdist;
|
|
|
|
/* Calculate allocation sizes */
|
|
Size samplesSize = VECTOR_ARRAY_SIZE(samples->maxlen, samples->itemsize);
|
|
Size centersSize = VECTOR_ARRAY_SIZE(centers->maxlen, centers->itemsize);
|
|
Size newCentersSize = VECTOR_ARRAY_SIZE(numCenters, centers->itemsize);
|
|
Size aggSize = sizeof(float) * (int64) numCenters * dimensions;
|
|
Size centerCountsSize = sizeof(int) * numCenters;
|
|
Size closestCentersSize = sizeof(int) * numSamples;
|
|
Size lowerBoundSize = sizeof(float) * numSamples * numCenters;
|
|
Size upperBoundSize = sizeof(float) * numSamples;
|
|
Size sSize = sizeof(float) * numCenters;
|
|
Size halfcdistSize = sizeof(float) * numCenters * numCenters;
|
|
Size newcdistSize = sizeof(float) * numCenters;
|
|
|
|
/* Calculate total size */
|
|
Size totalSize = samplesSize + centersSize + newCentersSize + aggSize + centerCountsSize + closestCentersSize + lowerBoundSize + upperBoundSize + sSize + halfcdistSize + newcdistSize;
|
|
|
|
/* Check memory requirements */
|
|
/* Add one to error message to ceil */
|
|
if (totalSize > (Size) maintenance_work_mem * 1024L)
|
|
ereport(ERROR,
|
|
(errcode(ERRCODE_PROGRAM_LIMIT_EXCEEDED),
|
|
errmsg("memory required is %zu MB, maintenance_work_mem is %d MB",
|
|
totalSize / (1024 * 1024) + 1, maintenance_work_mem / 1024)));
|
|
|
|
/* Ensure indexing does not overflow */
|
|
if (numCenters * numCenters > INT_MAX)
|
|
elog(ERROR, "Indexing overflow detected. Please report a bug.");
|
|
|
|
/* Set support functions */
|
|
procinfo = index_getprocinfo(index, 1, IVFFLAT_KMEANS_DISTANCE_PROC);
|
|
normprocinfo = IvfflatOptionalProcInfo(index, IVFFLAT_KMEANS_NORM_PROC);
|
|
collation = index->rd_indcollation[0];
|
|
|
|
/* Allocate space */
|
|
/* Use float instead of double to save memory */
|
|
agg = palloc(aggSize);
|
|
centerCounts = palloc(centerCountsSize);
|
|
closestCenters = palloc(closestCentersSize);
|
|
lowerBound = palloc_extended(lowerBoundSize, MCXT_ALLOC_HUGE);
|
|
upperBound = palloc(upperBoundSize);
|
|
s = palloc(sSize);
|
|
halfcdist = palloc_extended(halfcdistSize, MCXT_ALLOC_HUGE);
|
|
newcdist = palloc(newcdistSize);
|
|
|
|
/* Initialize new centers */
|
|
newCenters = VectorArrayInit(numCenters, dimensions, centers->itemsize);
|
|
newCenters->length = numCenters;
|
|
|
|
#ifdef IVFFLAT_MEMORY
|
|
ShowMemoryUsage(MemoryContextGetParent(CurrentMemoryContext), totalSize);
|
|
#endif
|
|
|
|
/* Pick initial centers */
|
|
InitCenters(index, samples, centers, lowerBound);
|
|
|
|
/* Assign each x to its closest initial center c(x) = argmin d(x,c) */
|
|
for (int64 j = 0; j < numSamples; j++)
|
|
{
|
|
float minDistance = FLT_MAX;
|
|
int closestCenter = 0;
|
|
|
|
/* Find closest center */
|
|
for (int64 k = 0; k < numCenters; k++)
|
|
{
|
|
/* TODO Use Lemma 1 in k-means++ initialization */
|
|
float distance = lowerBound[j * numCenters + k];
|
|
|
|
if (distance < minDistance)
|
|
{
|
|
minDistance = distance;
|
|
closestCenter = k;
|
|
}
|
|
}
|
|
|
|
upperBound[j] = minDistance;
|
|
closestCenters[j] = closestCenter;
|
|
}
|
|
|
|
/* Give 500 iterations to converge */
|
|
for (int iteration = 0; iteration < 500; iteration++)
|
|
{
|
|
int changes = 0;
|
|
bool rjreset;
|
|
|
|
/* Can take a while, so ensure we can interrupt */
|
|
CHECK_FOR_INTERRUPTS();
|
|
|
|
/* Step 1: For all centers, compute distance */
|
|
for (int64 j = 0; j < numCenters; j++)
|
|
{
|
|
Datum vec = PointerGetDatum(VectorArrayGet(centers, j));
|
|
|
|
for (int64 k = j + 1; k < numCenters; k++)
|
|
{
|
|
float distance = 0.5 * DatumGetFloat8(FunctionCall2Coll(procinfo, collation, vec, PointerGetDatum(VectorArrayGet(centers, k))));
|
|
|
|
halfcdist[j * numCenters + k] = distance;
|
|
halfcdist[k * numCenters + j] = distance;
|
|
}
|
|
}
|
|
|
|
/* For all centers c, compute s(c) */
|
|
for (int64 j = 0; j < numCenters; j++)
|
|
{
|
|
float minDistance = FLT_MAX;
|
|
|
|
for (int64 k = 0; k < numCenters; k++)
|
|
{
|
|
float distance;
|
|
|
|
if (j == k)
|
|
continue;
|
|
|
|
distance = halfcdist[j * numCenters + k];
|
|
if (distance < minDistance)
|
|
minDistance = distance;
|
|
}
|
|
|
|
s[j] = minDistance;
|
|
}
|
|
|
|
rjreset = iteration != 0;
|
|
|
|
for (int64 j = 0; j < numSamples; j++)
|
|
{
|
|
bool rj;
|
|
|
|
/* Step 2: Identify all points x such that u(x) <= s(c(x)) */
|
|
if (upperBound[j] <= s[closestCenters[j]])
|
|
continue;
|
|
|
|
rj = rjreset;
|
|
|
|
for (int64 k = 0; k < numCenters; k++)
|
|
{
|
|
Datum vec;
|
|
float dxcx;
|
|
|
|
/* Step 3: For all remaining points x and centers c */
|
|
if (k == closestCenters[j])
|
|
continue;
|
|
|
|
if (upperBound[j] <= lowerBound[j * numCenters + k])
|
|
continue;
|
|
|
|
if (upperBound[j] <= halfcdist[closestCenters[j] * numCenters + k])
|
|
continue;
|
|
|
|
vec = PointerGetDatum(VectorArrayGet(samples, j));
|
|
|
|
/* Step 3a */
|
|
if (rj)
|
|
{
|
|
dxcx = DatumGetFloat8(FunctionCall2Coll(procinfo, collation, vec, PointerGetDatum(VectorArrayGet(centers, closestCenters[j]))));
|
|
|
|
/* d(x,c(x)) computed, which is a form of d(x,c) */
|
|
lowerBound[j * numCenters + closestCenters[j]] = dxcx;
|
|
upperBound[j] = dxcx;
|
|
|
|
rj = false;
|
|
}
|
|
else
|
|
dxcx = upperBound[j];
|
|
|
|
/* Step 3b */
|
|
if (dxcx > lowerBound[j * numCenters + k] || dxcx > halfcdist[closestCenters[j] * numCenters + k])
|
|
{
|
|
float dxc = DatumGetFloat8(FunctionCall2Coll(procinfo, collation, vec, PointerGetDatum(VectorArrayGet(centers, k))));
|
|
|
|
/* d(x,c) calculated */
|
|
lowerBound[j * numCenters + k] = dxc;
|
|
|
|
if (dxc < dxcx)
|
|
{
|
|
closestCenters[j] = k;
|
|
|
|
/* c(x) changed */
|
|
upperBound[j] = dxc;
|
|
|
|
changes++;
|
|
}
|
|
}
|
|
}
|
|
}
|
|
|
|
/* Step 4: For each center c, let m(c) be mean of all points assigned */
|
|
ComputeNewCenters(samples, agg, newCenters, centerCounts, closestCenters, normprocinfo, collation, typeInfo);
|
|
|
|
/* Step 5 */
|
|
for (int j = 0; j < numCenters; j++)
|
|
newcdist[j] = DatumGetFloat8(FunctionCall2Coll(procinfo, collation, PointerGetDatum(VectorArrayGet(centers, j)), PointerGetDatum(VectorArrayGet(newCenters, j))));
|
|
|
|
for (int64 j = 0; j < numSamples; j++)
|
|
{
|
|
for (int64 k = 0; k < numCenters; k++)
|
|
{
|
|
float distance = lowerBound[j * numCenters + k] - newcdist[k];
|
|
|
|
if (distance < 0)
|
|
distance = 0;
|
|
|
|
lowerBound[j * numCenters + k] = distance;
|
|
}
|
|
}
|
|
|
|
/* Step 6 */
|
|
/* We reset r(x) before Step 3 in the next iteration */
|
|
for (int j = 0; j < numSamples; j++)
|
|
upperBound[j] += newcdist[closestCenters[j]];
|
|
|
|
/* Step 7 */
|
|
for (int j = 0; j < numCenters; j++)
|
|
VectorArraySet(centers, j, VectorArrayGet(newCenters, j));
|
|
|
|
if (changes == 0 && iteration != 0)
|
|
break;
|
|
}
|
|
}
|
|
|
|
/*
|
|
* Ensure no NaN or infinite values
|
|
*/
|
|
static void
|
|
CheckElements(VectorArray centers, const IvfflatTypeInfo * typeInfo)
|
|
{
|
|
float *scratch = palloc(sizeof(float) * centers->dim);
|
|
|
|
for (int i = 0; i < centers->length; i++)
|
|
{
|
|
for (int j = 0; j < centers->dim; j++)
|
|
scratch[j] = 0;
|
|
|
|
/* /fp:fast may not propagate NaN with MSVC, but that's alright */
|
|
typeInfo->sumCenter(VectorArrayGet(centers, i), scratch);
|
|
|
|
for (int j = 0; j < centers->dim; j++)
|
|
{
|
|
if (isnan(scratch[j]))
|
|
elog(ERROR, "NaN detected. Please report a bug.");
|
|
|
|
if (isinf(scratch[j]))
|
|
elog(ERROR, "Infinite value detected. Please report a bug.");
|
|
}
|
|
}
|
|
}
|
|
|
|
/*
|
|
* Ensure no zero vectors for cosine distance
|
|
*/
|
|
static void
|
|
CheckNorms(VectorArray centers, Relation index)
|
|
{
|
|
/* Check NORM_PROC instead of KMEANS_NORM_PROC */
|
|
FmgrInfo *normprocinfo = IvfflatOptionalProcInfo(index, IVFFLAT_NORM_PROC);
|
|
Oid collation = index->rd_indcollation[0];
|
|
|
|
if (normprocinfo == NULL)
|
|
return;
|
|
|
|
for (int i = 0; i < centers->length; i++)
|
|
{
|
|
double norm = DatumGetFloat8(FunctionCall1Coll(normprocinfo, collation, PointerGetDatum(VectorArrayGet(centers, i))));
|
|
|
|
if (norm == 0)
|
|
elog(ERROR, "Zero norm detected. Please report a bug.");
|
|
}
|
|
}
|
|
|
|
/*
|
|
* Detect issues with centers
|
|
*/
|
|
static void
|
|
CheckCenters(Relation index, VectorArray centers, const IvfflatTypeInfo * typeInfo)
|
|
{
|
|
if (centers->length != centers->maxlen)
|
|
elog(ERROR, "Not enough centers. Please report a bug.");
|
|
|
|
CheckElements(centers, typeInfo);
|
|
CheckNorms(centers, index);
|
|
}
|
|
|
|
/*
|
|
* Perform naive k-means centering
|
|
* We use spherical k-means for inner product and cosine
|
|
*/
|
|
void
|
|
IvfflatKmeans(Relation index, VectorArray samples, VectorArray centers, const IvfflatTypeInfo * typeInfo)
|
|
{
|
|
MemoryContext kmeansCtx = AllocSetContextCreate(CurrentMemoryContext,
|
|
"Ivfflat kmeans temporary context",
|
|
ALLOCSET_DEFAULT_SIZES);
|
|
MemoryContext oldCtx = MemoryContextSwitchTo(kmeansCtx);
|
|
|
|
if (samples->length == 0)
|
|
RandomCenters(index, centers, typeInfo);
|
|
else
|
|
ElkanKmeans(index, samples, centers, typeInfo);
|
|
|
|
CheckCenters(index, centers, typeInfo);
|
|
|
|
MemoryContextSwitchTo(oldCtx);
|
|
MemoryContextDelete(kmeansCtx);
|
|
}
|