mirror of
https://github.com/pgvector/pgvector.git
synced 2026-06-06 05:51:21 +08:00
Use f32 [skip ci]
This commit is contained in:
@@ -34,17 +34,16 @@ HalfvecL2SquaredDistanceDefault(int dim, half * ax, half * bx)
|
||||
|
||||
/* TODO Improve */
|
||||
#ifdef __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
|
||||
int count = (dim / 8) * 8;
|
||||
int count = (dim / 4) * 4;
|
||||
float32x4_t dist = vmovq_n_f32(0);
|
||||
|
||||
for (; i < count; i += 8)
|
||||
for (; i < count; i += 4)
|
||||
{
|
||||
float16x8_t axs = vld1q_f16((const __fp16 *) (ax + i));
|
||||
float16x8_t bxs = vld1q_f16((const __fp16 *) (bx + i));
|
||||
float16x8_t diff = vsubq_f16(axs, bxs);
|
||||
float16x4_t axs = vld1_f16((const __fp16 *) (ax + i));
|
||||
float16x4_t bxs = vld1_f16((const __fp16 *) (bx + i));
|
||||
float32x4_t diff = vsubq_f32(vcvt_f32_f16(axs), vcvt_f32_f16(bxs));
|
||||
|
||||
dist = vfmlalq_low_f16(dist, diff, diff);
|
||||
dist = vfmlalq_high_f16(dist, diff, diff);
|
||||
dist = vfmaq_f32(dist, diff, diff);
|
||||
}
|
||||
|
||||
distance += vaddvq_f32(dist);
|
||||
|
||||
Reference in New Issue
Block a user