mirror of
https://github.com/pgvector/pgvector.git
synced 2026-07-02 02:31:16 +08:00
Switched to support function for normalizing centers for k-means
This commit is contained in:
@@ -89,33 +89,27 @@ InitCenters(Relation index, VectorArray samples, VectorArray centers, float *low
|
||||
}
|
||||
|
||||
/*
|
||||
* Apply norm to vector
|
||||
* Norm centers
|
||||
*/
|
||||
static inline void
|
||||
ApplyNorm(FmgrInfo *normprocinfo, Oid collation, Datum value, IvfflatType type)
|
||||
static void
|
||||
NormCenters(FmgrInfo *normalizeprocinfo, Oid collation, VectorArray centers)
|
||||
{
|
||||
double norm = DatumGetFloat8(FunctionCall1Coll(normprocinfo, collation, value));
|
||||
MemoryContext normCtx = AllocSetContextCreate(CurrentMemoryContext,
|
||||
"Ivfflat norm temporary context",
|
||||
ALLOCSET_DEFAULT_SIZES);
|
||||
MemoryContext oldCtx = MemoryContextSwitchTo(normCtx);
|
||||
|
||||
/* TODO Handle zero norm */
|
||||
if (norm > 0)
|
||||
for (int j = 0; j < centers->maxlen; j++)
|
||||
{
|
||||
if (type == IVFFLAT_TYPE_VECTOR)
|
||||
{
|
||||
Vector *vec = DatumGetVector(value);
|
||||
Datum center = PointerGetDatum(VectorArrayGet(centers, j));
|
||||
Datum newCenter = IvfflatNormValue(normalizeprocinfo, collation, center);
|
||||
|
||||
for (int i = 0; i < vec->dim; i++)
|
||||
vec->x[i] /= norm;
|
||||
}
|
||||
else if (type == IVFFLAT_TYPE_HALFVEC)
|
||||
{
|
||||
HalfVector *vec = DatumGetHalfVector(value);
|
||||
|
||||
for (int i = 0; i < vec->dim; i++)
|
||||
vec->x[i] = Float4ToHalfUnchecked(HalfToFloat4(vec->x[i]) / norm);
|
||||
}
|
||||
else
|
||||
elog(ERROR, "Unsupported type");
|
||||
memcpy(DatumGetPointer(center), DatumGetPointer(newCenter), VARSIZE_ANY(DatumGetPointer(newCenter)));
|
||||
MemoryContextReset(normCtx);
|
||||
}
|
||||
|
||||
MemoryContextSwitchTo(oldCtx);
|
||||
MemoryContextDelete(normCtx);
|
||||
}
|
||||
|
||||
/*
|
||||
@@ -154,6 +148,7 @@ QuickCenters(Relation index, VectorArray samples, VectorArray centers, IvfflatTy
|
||||
int dimensions = centers->dim;
|
||||
Oid collation = index->rd_indcollation[0];
|
||||
FmgrInfo *normprocinfo = IvfflatOptionalProcInfo(index, IVFFLAT_KMEANS_NORM_PROC);
|
||||
FmgrInfo *normalizeprocinfo = IvfflatOptionalProcInfo(index, IVFFLAT_NORMALIZE_PROC);
|
||||
|
||||
/* Copy existing vectors while avoiding duplicates */
|
||||
if (samples->length > 0)
|
||||
@@ -217,12 +212,12 @@ QuickCenters(Relation index, VectorArray samples, VectorArray centers, IvfflatTy
|
||||
else
|
||||
elog(ERROR, "Unsupported type");
|
||||
|
||||
/* Normalize if needed (only needed for random centers) */
|
||||
if (normprocinfo != NULL)
|
||||
ApplyNorm(normprocinfo, collation, center, type);
|
||||
|
||||
centers->length++;
|
||||
}
|
||||
|
||||
/* Fine if existing vectors are normalized twice */
|
||||
if (normprocinfo != NULL)
|
||||
NormCenters(normalizeprocinfo, collation, centers);
|
||||
}
|
||||
|
||||
#ifdef IVFFLAT_MEMORY
|
||||
@@ -246,7 +241,7 @@ ShowMemoryUsage(MemoryContext context, Size estimatedSize)
|
||||
* Compute new centers
|
||||
*/
|
||||
static void
|
||||
ComputeNewCenters(VectorArray samples, VectorArray aggCenters, VectorArray newCenters, int *centerCounts, int *closestCenters, FmgrInfo *normprocinfo, Oid collation, IvfflatType type)
|
||||
ComputeNewCenters(VectorArray samples, VectorArray aggCenters, VectorArray newCenters, int *centerCounts, int *closestCenters, FmgrInfo *normprocinfo, FmgrInfo *normalizeprocinfo, Oid collation, IvfflatType type)
|
||||
{
|
||||
int dimensions = aggCenters->dim;
|
||||
int numCenters = aggCenters->maxlen;
|
||||
@@ -360,14 +355,7 @@ ComputeNewCenters(VectorArray samples, VectorArray aggCenters, VectorArray newCe
|
||||
|
||||
/* Normalize if needed */
|
||||
if (normprocinfo != NULL)
|
||||
{
|
||||
for (int j = 0; j < numCenters; j++)
|
||||
{
|
||||
Datum newCenter = PointerGetDatum(VectorArrayGet(newCenters, j));
|
||||
|
||||
ApplyNorm(normprocinfo, collation, newCenter, type);
|
||||
}
|
||||
}
|
||||
NormCenters(normalizeprocinfo, collation, newCenters);
|
||||
}
|
||||
|
||||
/*
|
||||
@@ -383,6 +371,7 @@ ElkanKmeans(Relation index, VectorArray samples, VectorArray centers, IvfflatTyp
|
||||
{
|
||||
FmgrInfo *procinfo;
|
||||
FmgrInfo *normprocinfo;
|
||||
FmgrInfo *normalizeprocinfo;
|
||||
Oid collation;
|
||||
int dimensions = centers->dim;
|
||||
int numCenters = centers->maxlen;
|
||||
@@ -430,6 +419,7 @@ ElkanKmeans(Relation index, VectorArray samples, VectorArray centers, IvfflatTyp
|
||||
/* Set support functions */
|
||||
procinfo = index_getprocinfo(index, 1, IVFFLAT_KMEANS_DISTANCE_PROC);
|
||||
normprocinfo = IvfflatOptionalProcInfo(index, IVFFLAT_KMEANS_NORM_PROC);
|
||||
normalizeprocinfo = IvfflatOptionalProcInfo(index, IVFFLAT_NORMALIZE_PROC);
|
||||
collation = index->rd_indcollation[0];
|
||||
|
||||
/* Use memory context */
|
||||
@@ -627,7 +617,7 @@ ElkanKmeans(Relation index, VectorArray samples, VectorArray centers, IvfflatTyp
|
||||
}
|
||||
|
||||
/* Step 4: For each center c, let m(c) be mean of all points assigned */
|
||||
ComputeNewCenters(samples, aggCenters, newCenters, centerCounts, closestCenters, normprocinfo, collation, type);
|
||||
ComputeNewCenters(samples, aggCenters, newCenters, centerCounts, closestCenters, normprocinfo, normalizeprocinfo, collation, type);
|
||||
|
||||
/* Step 5 */
|
||||
for (int j = 0; j < numCenters; j++)
|
||||
|
||||
Reference in New Issue
Block a user