Mark type-specific code [skip ci]

This commit is contained in:
Andrew Kane
2024-04-11 16:44:10 -07:00
parent 000cc13c29
commit 1c26da6ef5
3 changed files with 32 additions and 19 deletions

View File

@@ -437,7 +437,7 @@ ComputeCenters(IvfflatBuildState * buildstate)
}
/* Calculate centers */
IvfflatBench("k-means", IvfflatKmeans(buildstate->index, buildstate->samples, buildstate->centers));
IvfflatBench("k-means", IvfflatKmeans(buildstate->index, buildstate->samples, buildstate->centers, buildstate->type));
/* Free samples before we allocate more memory */
VectorArrayFree(buildstate->samples);

View File

@@ -270,7 +270,7 @@ typedef IvfflatScanOpaqueData * IvfflatScanOpaque;
/* Methods */
VectorArray VectorArrayInit(int maxlen, int dimensions, Size itemsize);
void VectorArrayFree(VectorArray arr);
void IvfflatKmeans(Relation index, VectorArray samples, VectorArray centers);
void IvfflatKmeans(Relation index, VectorArray samples, VectorArray centers, IvfflatType type);
FmgrInfo *IvfflatOptionalProcInfo(Relation index, uint16 procnum);
IvfflatType IvfflatGetType(Relation index);
bool IvfflatNormValue(FmgrInfo *procinfo, Oid collation, Datum *value, IvfflatType type);

View File

@@ -112,7 +112,7 @@ CompareVectors(const void *a, const void *b)
* Quick approach if we have little data
*/
static void
QuickCenters(Relation index, VectorArray samples, VectorArray centers)
QuickCenters(Relation index, VectorArray samples, VectorArray centers, IvfflatType type)
{
int dimensions = centers->dim;
Oid collation = index->rd_indcollation[0];
@@ -121,7 +121,11 @@ QuickCenters(Relation index, VectorArray samples, VectorArray centers)
/* Copy existing vectors while avoiding duplicates */
if (samples->length > 0)
{
qsort(samples->items, samples->length, samples->itemsize, CompareVectors);
if (type == IVFFLAT_TYPE_VECTOR)
qsort(samples->items, samples->length, samples->itemsize, CompareVectors);
else
elog(ERROR, "Unsupported type");
for (int i = 0; i < samples->length; i++)
{
Datum vec = PointerGetDatum(VectorArrayGet(samples, i));
@@ -137,17 +141,22 @@ QuickCenters(Relation index, VectorArray samples, VectorArray centers)
/* Fill remaining with random data */
while (centers->length < centers->maxlen)
{
Vector *vec = (Vector *) VectorArrayGet(centers, centers->length);
if (type == IVFFLAT_TYPE_VECTOR)
{
Vector *vec = (Vector *) VectorArrayGet(centers, centers->length);
SET_VARSIZE(vec, VECTOR_SIZE(dimensions));
vec->dim = dimensions;
SET_VARSIZE(vec, VECTOR_SIZE(dimensions));
vec->dim = dimensions;
for (int j = 0; j < dimensions; j++)
vec->x[j] = RandomDouble();
for (int j = 0; j < dimensions; j++)
vec->x[j] = RandomDouble();
/* Normalize if needed (only needed for random centers) */
if (normprocinfo != NULL)
ApplyNorm(normprocinfo, collation, vec);
/* Normalize if needed (only needed for random centers) */
if (normprocinfo != NULL)
ApplyNorm(normprocinfo, collation, vec);
}
else
elog(ERROR, "Unsupported type");
centers->length++;
}
@@ -179,7 +188,7 @@ ShowMemoryUsage(Size estimatedSize)
* https://www.aaai.org/Papers/ICML/2003/ICML03-022.pdf
*/
static void
ElkanKmeans(Relation index, VectorArray samples, VectorArray centers)
ElkanKmeans(Relation index, VectorArray samples, VectorArray centers, IvfflatType type)
{
FmgrInfo *procinfo;
FmgrInfo *normprocinfo;
@@ -483,7 +492,7 @@ ElkanKmeans(Relation index, VectorArray samples, VectorArray centers)
* Detect issues with centers
*/
static void
CheckCenters(Relation index, VectorArray centers)
CheckCenters(Relation index, VectorArray centers, IvfflatType type)
{
FmgrInfo *normprocinfo;
@@ -507,7 +516,11 @@ CheckCenters(Relation index, VectorArray centers)
/* Ensure no duplicate centers */
/* Fine to sort in-place */
qsort(centers->items, centers->length, centers->itemsize, CompareVectors);
if (type == IVFFLAT_TYPE_VECTOR)
qsort(centers->items, centers->length, centers->itemsize, CompareVectors);
else
elog(ERROR, "Unsupported type");
for (int i = 1; i < centers->length; i++)
{
if (datumIsEqual(PointerGetDatum(VectorArrayGet(centers, i)), PointerGetDatum(VectorArrayGet(centers, i - 1)), false, -1))
@@ -536,12 +549,12 @@ CheckCenters(Relation index, VectorArray centers)
* We use spherical k-means for inner product and cosine
*/
void
IvfflatKmeans(Relation index, VectorArray samples, VectorArray centers)
IvfflatKmeans(Relation index, VectorArray samples, VectorArray centers, IvfflatType type)
{
if (samples->length <= centers->maxlen)
QuickCenters(index, samples, centers);
QuickCenters(index, samples, centers, type);
else
ElkanKmeans(index, samples, centers);
ElkanKmeans(index, samples, centers, type);
CheckCenters(index, centers);
CheckCenters(index, centers, type);
}