From 66a29dbdf3e29a8ff287f08f5af498584fc7aa39 Mon Sep 17 00:00:00 2001 From: Andrew Kane Date: Thu, 11 Apr 2024 16:50:21 -0700 Subject: [PATCH] Switched to Datum for ApplyNorm [skip ci] --- src/ivfkmeans.c | 21 ++++++++++++++------- 1 file changed, 14 insertions(+), 7 deletions(-) diff --git a/src/ivfkmeans.c b/src/ivfkmeans.c index bb1734b..ee161c0 100644 --- a/src/ivfkmeans.c +++ b/src/ivfkmeans.c @@ -87,15 +87,22 @@ InitCenters(Relation index, VectorArray samples, VectorArray centers, float *low * Apply norm to vector */ static inline void -ApplyNorm(FmgrInfo *normprocinfo, Oid collation, Vector * vec) +ApplyNorm(FmgrInfo *normprocinfo, Oid collation, Datum value, IvfflatType type) { - double norm = DatumGetFloat8(FunctionCall1Coll(normprocinfo, collation, PointerGetDatum(vec))); + double norm = DatumGetFloat8(FunctionCall1Coll(normprocinfo, collation, value)); /* TODO Handle zero norm */ if (norm > 0) { - for (int i = 0; i < vec->dim; i++) - vec->x[i] /= norm; + if (type == IVFFLAT_TYPE_VECTOR) + { + Vector *vec = DatumGetVector(value); + + for (int i = 0; i < vec->dim; i++) + vec->x[i] /= norm; + } + else + elog(ERROR, "Unsupported type"); } } @@ -128,7 +135,7 @@ QuickCenters(Relation index, VectorArray samples, VectorArray centers, IvfflatTy for (int i = 0; i < samples->length; i++) { - Datum vec = PointerGetDatum(VectorArrayGet(samples, i)); + Datum vec = PointerGetDatum(VectorArrayGet(samples, i)); if (i == 0 || !datumIsEqual(vec, PointerGetDatum(VectorArrayGet(samples, i - 1)), false, -1)) { @@ -153,7 +160,7 @@ QuickCenters(Relation index, VectorArray samples, VectorArray centers, IvfflatTy /* Normalize if needed (only needed for random centers) */ if (normprocinfo != NULL) - ApplyNorm(normprocinfo, collation, vec); + ApplyNorm(normprocinfo, collation, PointerGetDatum(vec), type); } else elog(ERROR, "Unsupported type"); @@ -451,7 +458,7 @@ ElkanKmeans(Relation index, VectorArray samples, VectorArray centers, IvfflatTyp /* Normalize if needed */ if (normprocinfo != NULL) - ApplyNorm(normprocinfo, collation, vec); + ApplyNorm(normprocinfo, collation, PointerGetDatum(vec), type); } /* Step 5 */