diff --git a/src/halfutils.c b/src/halfutils.c index 6f499dd..0a604a6 100644 --- a/src/halfutils.c +++ b/src/halfutils.c @@ -1,5 +1,7 @@ #include "postgres.h" +#include + #include "halfutils.h" #include "halfvec.h" @@ -28,9 +30,28 @@ static float HalfvecL2SquaredDistanceDefault(int dim, half * ax, half * bx) { float distance = 0.0; + int i = 0; + +/* TODO Improve */ +#ifdef __ARM_FEATURE_FP16_VECTOR_ARITHMETIC + int count = (dim / 8) * 8; + float32x4_t dist = vmovq_n_f32(0); + + for (; i < count; i += 8) + { + float16x8_t axs = vld1q_f16((const __fp16 *) (ax + i)); + float16x8_t bxs = vld1q_f16((const __fp16 *) (bx + i)); + float16x8_t diff = vsubq_f16(axs, bxs); + + dist = vfmlalq_low_f16(dist, diff, diff); + dist = vfmlalq_high_f16(dist, diff, diff); + } + + distance += vaddvq_f32(dist); +#endif /* Auto-vectorized */ - for (int i = 0; i < dim; i++) + for (; i < dim; i++) { float diff = HalfToFloat4(ax[i]) - HalfToFloat4(bx[i]);