diff --git a/src/ivfkmeans.c b/src/ivfkmeans.c index ee161c0..c2c18f1 100644 --- a/src/ivfkmeans.c +++ b/src/ivfkmeans.c @@ -148,23 +148,25 @@ QuickCenters(Relation index, VectorArray samples, VectorArray centers, IvfflatTy /* Fill remaining with random data */ while (centers->length < centers->maxlen) { + Datum center = PointerGetDatum(VectorArrayGet(centers, centers->length)); + if (type == IVFFLAT_TYPE_VECTOR) { - Vector *vec = (Vector *) VectorArrayGet(centers, centers->length); + Vector *vec = DatumGetVector(center); SET_VARSIZE(vec, VECTOR_SIZE(dimensions)); vec->dim = dimensions; for (int j = 0; j < dimensions; j++) vec->x[j] = RandomDouble(); - - /* Normalize if needed (only needed for random centers) */ - if (normprocinfo != NULL) - ApplyNorm(normprocinfo, collation, PointerGetDatum(vec), type); } else elog(ERROR, "Unsupported type"); + /* Normalize if needed (only needed for random centers) */ + if (normprocinfo != NULL) + ApplyNorm(normprocinfo, collation, center, type); + centers->length++; } } @@ -434,7 +436,8 @@ ElkanKmeans(Relation index, VectorArray samples, VectorArray centers, IvfflatTyp for (int64 j = 0; j < numCenters; j++) { - Vector *vec = (Vector *) VectorArrayGet(newCenters, j); + Datum center = PointerGetDatum(VectorArrayGet(newCenters, j)); + Vector *vec = DatumGetVector(center); if (centerCounts[j] > 0) { @@ -458,7 +461,7 @@ ElkanKmeans(Relation index, VectorArray samples, VectorArray centers, IvfflatTyp /* Normalize if needed */ if (normprocinfo != NULL) - ApplyNorm(normprocinfo, collation, PointerGetDatum(vec), type); + ApplyNorm(normprocinfo, collation, center, type); } /* Step 5 */