Use f32 [skip ci]

This commit is contained in:
Andrew Kane
2024-04-27 23:02:00 -07:00
parent 17855c9861
commit 0ff9f6511a

View File

@@ -34,17 +34,16 @@ HalfvecL2SquaredDistanceDefault(int dim, half * ax, half * bx)
/* TODO Improve */
#ifdef __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
-int count = (dim / 8) * 8;
+int count = (dim / 4) * 4;
float32x4_t dist = vmovq_n_f32(0);
-for (; i < count; i += 8)
+for (; i < count; i += 4)
{
-float16x8_t axs = vld1q_f16((const __fp16 *) (ax + i));
-float16x8_t bxs = vld1q_f16((const __fp16 *) (bx + i));
-float16x8_t diff = vsubq_f16(axs, bxs);
+float16x4_t axs = vld1_f16((const __fp16 *) (ax + i));
+float16x4_t bxs = vld1_f16((const __fp16 *) (bx + i));
+float32x4_t diff = vsubq_f32(vcvt_f32_f16(axs), vcvt_f32_f16(bxs));
-dist = vfmlalq_low_f16(dist, diff, diff);
-dist = vfmlalq_high_f16(dist, diff, diff);
+dist = vfmaq_f32(dist, diff, diff);
}
distance += vaddvq_f32(dist);