diff --git a/src/halfvec.c b/src/halfvec.c index ee902e1..6a41295 100644 --- a/src/halfvec.c +++ b/src/halfvec.c @@ -94,7 +94,7 @@ pq_sendhalf(StringInfo buf, half h) /* * Convert a half to a float4 */ -static float +float HalfToFloat4(half num) { #ifdef FLT16_SUPPORT @@ -245,7 +245,7 @@ Float4ToHalfUnchecked(float num) /* * Convert a float4 to a half */ -static half +half Float4ToHalf(float num) { half result = Float4ToHalfUnchecked(num); diff --git a/src/halfvec.h b/src/halfvec.h index 21ef042..3d4e821 100644 --- a/src/halfvec.h +++ b/src/halfvec.h @@ -34,5 +34,7 @@ typedef struct HalfVector } HalfVector; HalfVector *InitHalfVector(int dim); +float HalfToFloat4(half num); +half Float4ToHalf(float num); #endif diff --git a/src/hnswutils.c b/src/hnswutils.c index 74fd3c3..0b8c7c5 100644 --- a/src/hnswutils.c +++ b/src/hnswutils.c @@ -199,7 +199,7 @@ HnswNormValue(FmgrInfo *procinfo, Oid collation, Datum *value, int type) HalfVector *result = InitHalfVector(v->dim); for (int i = 0; i < v->dim; i++) - result->x[i] = v->x[i] / norm; + result->x[i] = Float4ToHalf(HalfToFloat4(v->x[i]) / norm); *value = PointerGetDatum(result); }