Improved QuickCenters [skip ci]

This commit is contained in:
Andrew Kane
2024-04-24 16:38:14 -07:00
parent c4484c90d9
commit 25b98540c9

View File

@@ -163,6 +163,64 @@ SortVectorArray(VectorArray arr, IvfflatType type)
qsort(arr->items, arr->length, arr->itemsize, comp);
}
static void
VectorInitNewCenter(Pointer v, int dimensions)
{
Vector *vec = (Vector *) v;
SET_VARSIZE(vec, VECTOR_SIZE(dimensions));
vec->dim = dimensions;
}
static void
HalfvecInitNewCenter(Pointer v, int dimensions)
{
HalfVector *vec = (HalfVector *) v;
SET_VARSIZE(vec, HALFVEC_SIZE(dimensions));
vec->dim = dimensions;
}
static void
BitInitNewCenter(Pointer v, int dimensions)
{
VarBit *vec = (VarBit *) v;
SET_VARSIZE(vec, VARBITTOTALLEN(dimensions));
VARBITLEN(vec) = dimensions;
}
static void
VectorSetNewCenter(Pointer v, float *x)
{
Vector *newCenter = (Vector *) v;
for (int k = 0; k < newCenter->dim; k++)
newCenter->x[k] = x[k];
}
static void
HalfvecSetNewCenter(Pointer v, float *x)
{
HalfVector *newCenter = (HalfVector *) v;
for (int k = 0; k < newCenter->dim; k++)
newCenter->x[k] = Float4ToHalfUnchecked(x[k]);
}
static void
BitSetNewCenter(Pointer v, float *x)
{
VarBit *newCenter = (VarBit *) v;
unsigned char *nx = VARBITS(newCenter);
for (uint32 k = 0; k < VARBITBYTES(newCenter); k++)
nx[k] = 0;
for (int k = 0; k < VARBITLEN(newCenter); k++)
nx[k / 8] |= (x[k] > 0.5 ? 1 : 0) << (7 - (k % 8));
}
/*
* Quick approach if we have little data
*/
@@ -173,6 +231,27 @@ QuickCenters(Relation index, VectorArray samples, VectorArray centers, IvfflatTy
Oid collation = index->rd_indcollation[0];
FmgrInfo *normprocinfo = IvfflatOptionalProcInfo(index, IVFFLAT_KMEANS_NORM_PROC);
FmgrInfo *normalizeprocinfo = IvfflatOptionalProcInfo(index, IVFFLAT_NORMALIZE_PROC);
void (*initNewCenter) (Pointer v, int dimensions);
void (*setNewCenter) (Pointer v, float *x);
float *x = (float *) palloc(sizeof(float) * dimensions);
if (type == IVFFLAT_TYPE_VECTOR)
{
initNewCenter = VectorInitNewCenter;
setNewCenter = VectorSetNewCenter;
}
else if (type == IVFFLAT_TYPE_HALFVEC)
{
initNewCenter = HalfvecInitNewCenter;
setNewCenter = HalfvecSetNewCenter;
}
else if (type == IVFFLAT_TYPE_BIT)
{
initNewCenter = BitInitNewCenter;
setNewCenter = BitSetNewCenter;
}
else
elog(ERROR, "Unsupported type");
/* Copy existing vectors while avoiding duplicates */
if (samples->length > 0)
@@ -194,40 +273,13 @@ QuickCenters(Relation index, VectorArray samples, VectorArray centers, IvfflatTy
/* Fill remaining with random data */
while (centers->length < centers->maxlen)
{
Datum center = PointerGetDatum(VectorArrayGet(centers, centers->length));
Pointer center = VectorArrayGet(centers, centers->length);
if (type == IVFFLAT_TYPE_VECTOR)
{
Vector *vec = DatumGetVector(center);
for (int i = 0; i < dimensions; i++)
x[i] = (float) RandomDouble();
SET_VARSIZE(vec, VECTOR_SIZE(dimensions));
vec->dim = dimensions;
for (int j = 0; j < dimensions; j++)
vec->x[j] = RandomDouble();
}
else if (type == IVFFLAT_TYPE_HALFVEC)
{
HalfVector *vec = DatumGetHalfVector(center);
SET_VARSIZE(vec, HALFVEC_SIZE(dimensions));
vec->dim = dimensions;
for (int j = 0; j < dimensions; j++)
vec->x[j] = Float4ToHalfUnchecked((float) RandomDouble());
}
else if (type == IVFFLAT_TYPE_BIT)
{
VarBit *vec = DatumGetVarBitP(center);
SET_VARSIZE(vec, VARBITTOTALLEN(dimensions));
VARBITLEN(vec) = dimensions;
for (int j = 0; j < dimensions; j++)
VARBITS(vec)[j / dimensions] |= (RandomDouble() > 0.5 ? 1 : 0) << (7 - (j % 8));
}
else
elog(ERROR, "Unsupported type");
initNewCenter(center, dimensions);
setNewCenter(center, x);
centers->length++;
}
@@ -235,6 +287,8 @@ QuickCenters(Relation index, VectorArray samples, VectorArray centers, IvfflatTy
/* Fine if existing vectors are normalized twice */
if (normprocinfo != NULL)
NormCenters(normalizeprocinfo, collation, centers);
pfree(x);
}
#ifdef IVFFLAT_MEMORY
@@ -306,28 +360,6 @@ SumCenters(VectorArray samples, VectorArray aggCenters, int *closestCenters, Ivf
}
}
static void
HalfvecSetNewCenter(Pointer v, float *x)
{
HalfVector *newCenter = (HalfVector *) v;
for (int k = 0; k < newCenter->dim; k++)
newCenter->x[k] = Float4ToHalfUnchecked(x[k]);
}
static void
BitSetNewCenter(Pointer v, float *x)
{
VarBit *newCenter = (VarBit *) v;
unsigned char *nx = VARBITS(newCenter);
for (uint32 k = 0; k < VARBITBYTES(newCenter); k++)
nx[k] = 0;
for (int k = 0; k < VARBITLEN(newCenter); k++)
nx[k / 8] |= (x[k] > 0.5) << (7 - (k % 8));
}
/*
* Set new centers
*/