diff --git a/src/ivfbuild.c b/src/ivfbuild.c index b6a95e5..cf01334 100644 --- a/src/ivfbuild.c +++ b/src/ivfbuild.c @@ -437,7 +437,7 @@ ComputeCenters(IvfflatBuildState * buildstate) } /* Calculate centers */ - IvfflatBench("k-means", IvfflatKmeans(buildstate->index, buildstate->samples, buildstate->centers)); + IvfflatBench("k-means", IvfflatKmeans(buildstate->index, buildstate->samples, buildstate->centers, buildstate->type)); /* Free samples before we allocate more memory */ VectorArrayFree(buildstate->samples); diff --git a/src/ivfflat.h b/src/ivfflat.h index a3a378a..034a039 100644 --- a/src/ivfflat.h +++ b/src/ivfflat.h @@ -270,7 +270,7 @@ typedef IvfflatScanOpaqueData * IvfflatScanOpaque; /* Methods */ VectorArray VectorArrayInit(int maxlen, int dimensions, Size itemsize); void VectorArrayFree(VectorArray arr); -void IvfflatKmeans(Relation index, VectorArray samples, VectorArray centers); +void IvfflatKmeans(Relation index, VectorArray samples, VectorArray centers, IvfflatType type); FmgrInfo *IvfflatOptionalProcInfo(Relation index, uint16 procnum); IvfflatType IvfflatGetType(Relation index); bool IvfflatNormValue(FmgrInfo *procinfo, Oid collation, Datum *value, IvfflatType type); diff --git a/src/ivfkmeans.c b/src/ivfkmeans.c index f739e7d..bb1734b 100644 --- a/src/ivfkmeans.c +++ b/src/ivfkmeans.c @@ -112,7 +112,7 @@ CompareVectors(const void *a, const void *b) * Quick approach if we have little data */ static void -QuickCenters(Relation index, VectorArray samples, VectorArray centers) +QuickCenters(Relation index, VectorArray samples, VectorArray centers, IvfflatType type) { int dimensions = centers->dim; Oid collation = index->rd_indcollation[0]; @@ -121,7 +121,11 @@ QuickCenters(Relation index, VectorArray samples, VectorArray centers) /* Copy existing vectors while avoiding duplicates */ if (samples->length > 0) { - qsort(samples->items, samples->length, samples->itemsize, CompareVectors); + if (type == IVFFLAT_TYPE_VECTOR) + qsort(samples->items, samples->length, samples->itemsize, CompareVectors); + else + elog(ERROR, "Unsupported type"); + for (int i = 0; i < samples->length; i++) { Datum vec = PointerGetDatum(VectorArrayGet(samples, i)); @@ -137,17 +141,22 @@ QuickCenters(Relation index, VectorArray samples, VectorArray centers) /* Fill remaining with random data */ while (centers->length < centers->maxlen) { - Vector *vec = (Vector *) VectorArrayGet(centers, centers->length); + if (type == IVFFLAT_TYPE_VECTOR) + { + Vector *vec = (Vector *) VectorArrayGet(centers, centers->length); - SET_VARSIZE(vec, VECTOR_SIZE(dimensions)); - vec->dim = dimensions; + SET_VARSIZE(vec, VECTOR_SIZE(dimensions)); + vec->dim = dimensions; - for (int j = 0; j < dimensions; j++) - vec->x[j] = RandomDouble(); + for (int j = 0; j < dimensions; j++) + vec->x[j] = RandomDouble(); - /* Normalize if needed (only needed for random centers) */ - if (normprocinfo != NULL) - ApplyNorm(normprocinfo, collation, vec); + /* Normalize if needed (only needed for random centers) */ + if (normprocinfo != NULL) + ApplyNorm(normprocinfo, collation, vec); + } + else + elog(ERROR, "Unsupported type"); centers->length++; } @@ -179,7 +188,7 @@ ShowMemoryUsage(Size estimatedSize) * https://www.aaai.org/Papers/ICML/2003/ICML03-022.pdf */ static void -ElkanKmeans(Relation index, VectorArray samples, VectorArray centers) +ElkanKmeans(Relation index, VectorArray samples, VectorArray centers, IvfflatType type) { FmgrInfo *procinfo; FmgrInfo *normprocinfo; @@ -483,7 +492,7 @@ ElkanKmeans(Relation index, VectorArray samples, VectorArray centers) * Detect issues with centers */ static void -CheckCenters(Relation index, VectorArray centers) +CheckCenters(Relation index, VectorArray centers, IvfflatType type) { FmgrInfo *normprocinfo; @@ -507,7 +516,11 @@ CheckCenters(Relation index, VectorArray centers) /* Ensure no duplicate centers */ /* Fine to sort in-place */ - qsort(centers->items, centers->length, centers->itemsize, CompareVectors); + if (type == IVFFLAT_TYPE_VECTOR) + qsort(centers->items, centers->length, centers->itemsize, CompareVectors); + else + elog(ERROR, "Unsupported type"); + for (int i = 1; i < centers->length; i++) { if (datumIsEqual(PointerGetDatum(VectorArrayGet(centers, i)), PointerGetDatum(VectorArrayGet(centers, i - 1)), false, -1)) @@ -536,12 +549,12 @@ CheckCenters(Relation index, VectorArray centers) * We use spherical k-means for inner product and cosine */ void -IvfflatKmeans(Relation index, VectorArray samples, VectorArray centers) +IvfflatKmeans(Relation index, VectorArray samples, VectorArray centers, IvfflatType type) { if (samples->length <= centers->maxlen) - QuickCenters(index, samples, centers); + QuickCenters(index, samples, centers, type); else - ElkanKmeans(index, samples, centers); + ElkanKmeans(index, samples, centers, type); - CheckCenters(index, centers); + CheckCenters(index, centers, type); }