From 0ff9f6511a3ebc3bb9e92c6e91b1c59ae250d763 Mon Sep 17 00:00:00 2001 From: Andrew Kane Date: Sat, 27 Apr 2024 23:02:00 -0700 Subject: [PATCH] Use f32 [skip ci] --- src/halfutils.c | 13 ++++++------- 1 file changed, 6 insertions(+), 7 deletions(-) diff --git a/src/halfutils.c b/src/halfutils.c index 0a604a6..19eadae 100644 --- a/src/halfutils.c +++ b/src/halfutils.c @@ -34,17 +34,16 @@ HalfvecL2SquaredDistanceDefault(int dim, half * ax, half * bx) /* TODO Improve */ #ifdef __ARM_FEATURE_FP16_VECTOR_ARITHMETIC - int count = (dim / 8) * 8; + int count = (dim / 4) * 4; float32x4_t dist = vmovq_n_f32(0); - for (; i < count; i += 8) + for (; i < count; i += 4) { - float16x8_t axs = vld1q_f16((const __fp16 *) (ax + i)); - float16x8_t bxs = vld1q_f16((const __fp16 *) (bx + i)); - float16x8_t diff = vsubq_f16(axs, bxs); + float16x4_t axs = vld1_f16((const __fp16 *) (ax + i)); + float16x4_t bxs = vld1_f16((const __fp16 *) (bx + i)); + float32x4_t diff = vsubq_f32(vcvt_f32_f16(axs), vcvt_f32_f16(bxs)); - dist = vfmlalq_low_f16(dist, diff, diff); - dist = vfmlalq_high_f16(dist, diff, diff); + dist = vfmaq_f32(dist, diff, diff); } distance += vaddvq_f32(dist);