Switched to support function for normalizing centers for k-means

This commit is contained in:
Andrew Kane
2024-04-23 15:39:58 -07:00
parent 0da6213a60
commit 9cd789fe06

View File

@@ -89,33 +89,27 @@ InitCenters(Relation index, VectorArray samples, VectorArray centers, float *low
}
/*
* Apply norm to vector
* Norm centers
*/
static inline void
ApplyNorm(FmgrInfo *normprocinfo, Oid collation, Datum value, IvfflatType type)
static void
NormCenters(FmgrInfo *normalizeprocinfo, Oid collation, VectorArray centers)
{
double norm = DatumGetFloat8(FunctionCall1Coll(normprocinfo, collation, value));
MemoryContext normCtx = AllocSetContextCreate(CurrentMemoryContext,
"Ivfflat norm temporary context",
ALLOCSET_DEFAULT_SIZES);
MemoryContext oldCtx = MemoryContextSwitchTo(normCtx);
/* TODO Handle zero norm */
if (norm > 0)
for (int j = 0; j < centers->maxlen; j++)
{
if (type == IVFFLAT_TYPE_VECTOR)
{
Vector *vec = DatumGetVector(value);
Datum center = PointerGetDatum(VectorArrayGet(centers, j));
Datum newCenter = IvfflatNormValue(normalizeprocinfo, collation, center);
for (int i = 0; i < vec->dim; i++)
vec->x[i] /= norm;
}
else if (type == IVFFLAT_TYPE_HALFVEC)
{
HalfVector *vec = DatumGetHalfVector(value);
for (int i = 0; i < vec->dim; i++)
vec->x[i] = Float4ToHalfUnchecked(HalfToFloat4(vec->x[i]) / norm);
}
else
elog(ERROR, "Unsupported type");
memcpy(DatumGetPointer(center), DatumGetPointer(newCenter), VARSIZE_ANY(DatumGetPointer(newCenter)));
MemoryContextReset(normCtx);
}
MemoryContextSwitchTo(oldCtx);
MemoryContextDelete(normCtx);
}
/*
@@ -154,6 +148,7 @@ QuickCenters(Relation index, VectorArray samples, VectorArray centers, IvfflatTy
int dimensions = centers->dim;
Oid collation = index->rd_indcollation[0];
FmgrInfo *normprocinfo = IvfflatOptionalProcInfo(index, IVFFLAT_KMEANS_NORM_PROC);
FmgrInfo *normalizeprocinfo = IvfflatOptionalProcInfo(index, IVFFLAT_NORMALIZE_PROC);
/* Copy existing vectors while avoiding duplicates */
if (samples->length > 0)
@@ -217,12 +212,12 @@ QuickCenters(Relation index, VectorArray samples, VectorArray centers, IvfflatTy
else
elog(ERROR, "Unsupported type");
/* Normalize if needed (only needed for random centers) */
if (normprocinfo != NULL)
ApplyNorm(normprocinfo, collation, center, type);
centers->length++;
}
/* Fine if existing vectors are normalized twice */
if (normprocinfo != NULL)
NormCenters(normalizeprocinfo, collation, centers);
}
#ifdef IVFFLAT_MEMORY
@@ -246,7 +241,7 @@ ShowMemoryUsage(MemoryContext context, Size estimatedSize)
* Compute new centers
*/
static void
ComputeNewCenters(VectorArray samples, VectorArray aggCenters, VectorArray newCenters, int *centerCounts, int *closestCenters, FmgrInfo *normprocinfo, Oid collation, IvfflatType type)
ComputeNewCenters(VectorArray samples, VectorArray aggCenters, VectorArray newCenters, int *centerCounts, int *closestCenters, FmgrInfo *normprocinfo, FmgrInfo *normalizeprocinfo, Oid collation, IvfflatType type)
{
int dimensions = aggCenters->dim;
int numCenters = aggCenters->maxlen;
@@ -360,14 +355,7 @@ ComputeNewCenters(VectorArray samples, VectorArray aggCenters, VectorArray newCe
/* Normalize if needed */
if (normprocinfo != NULL)
{
for (int j = 0; j < numCenters; j++)
{
Datum newCenter = PointerGetDatum(VectorArrayGet(newCenters, j));
ApplyNorm(normprocinfo, collation, newCenter, type);
}
}
NormCenters(normalizeprocinfo, collation, newCenters);
}
/*
@@ -383,6 +371,7 @@ ElkanKmeans(Relation index, VectorArray samples, VectorArray centers, IvfflatTyp
{
FmgrInfo *procinfo;
FmgrInfo *normprocinfo;
FmgrInfo *normalizeprocinfo;
Oid collation;
int dimensions = centers->dim;
int numCenters = centers->maxlen;
@@ -430,6 +419,7 @@ ElkanKmeans(Relation index, VectorArray samples, VectorArray centers, IvfflatTyp
/* Set support functions */
procinfo = index_getprocinfo(index, 1, IVFFLAT_KMEANS_DISTANCE_PROC);
normprocinfo = IvfflatOptionalProcInfo(index, IVFFLAT_KMEANS_NORM_PROC);
normalizeprocinfo = IvfflatOptionalProcInfo(index, IVFFLAT_NORMALIZE_PROC);
collation = index->rd_indcollation[0];
/* Use memory context */
@@ -627,7 +617,7 @@ ElkanKmeans(Relation index, VectorArray samples, VectorArray centers, IvfflatTyp
}
/* Step 4: For each center c, let m(c) be mean of all points assigned */
ComputeNewCenters(samples, aggCenters, newCenters, centerCounts, closestCenters, normprocinfo, collation, type);
ComputeNewCenters(samples, aggCenters, newCenters, centerCounts, closestCenters, normprocinfo, normalizeprocinfo, collation, type);
/* Step 5 */
for (int j = 0; j < numCenters; j++)