From 3bd67fef54ebb5d4ab251438eaa08551f41b966f Mon Sep 17 00:00:00 2001 From: Andrew Kane Date: Sun, 7 Apr 2024 19:03:20 -0700 Subject: [PATCH] DRY halfvec distance functions --- src/halfvec.c | 68 ++++++++++++++++++++++++++------------------------- 1 file changed, 35 insertions(+), 33 deletions(-) diff --git a/src/halfvec.c b/src/halfvec.c index 0d6d0c5..35cc0d0 100644 --- a/src/halfvec.c +++ b/src/halfvec.c @@ -801,20 +801,15 @@ vector_to_halfvec(PG_FUNCTION_ARGS) } /* - * Get the L2 distance between half vectors + * Get the L2 squared distance between half vectors */ -PGDLLEXPORT PG_FUNCTION_INFO_V1(halfvec_l2_distance); -Datum -halfvec_l2_distance(PG_FUNCTION_ARGS) +static double +l2_distance_squared_internal(HalfVector * a, HalfVector * b) { - HalfVector *a = PG_GETARG_HALFVEC_P(0); - HalfVector *b = PG_GETARG_HALFVEC_P(1); half *ax = a->x; half *bx = b->x; float distance = 0.0; - CheckDims(a, b); - /* Auto-vectorized */ for (int i = 0; i < a->dim; i++) { @@ -823,7 +818,22 @@ halfvec_l2_distance(PG_FUNCTION_ARGS) distance += diff * diff; } - PG_RETURN_FLOAT8(sqrt((double) distance)); + return (double) distance; +} + +/* + * Get the L2 distance between half vectors + */ +PGDLLEXPORT PG_FUNCTION_INFO_V1(halfvec_l2_distance); +Datum +halfvec_l2_distance(PG_FUNCTION_ARGS) +{ + HalfVector *a = PG_GETARG_HALFVEC_P(0); + HalfVector *b = PG_GETARG_HALFVEC_P(1); + + CheckDims(a, b); + + PG_RETURN_FLOAT8(sqrt(l2_distance_squared_internal(a, b))); } /* @@ -835,21 +845,27 @@ halfvec_l2_squared_distance(PG_FUNCTION_ARGS) { HalfVector *a = PG_GETARG_HALFVEC_P(0); HalfVector *b = PG_GETARG_HALFVEC_P(1); + + CheckDims(a, b); + + PG_RETURN_FLOAT8(l2_distance_squared_internal(a, b)); +} + +/* + * Get the inner product of two half vectors + */ +static double +inner_product_internal(HalfVector * a, HalfVector * b) +{ half *ax = a->x; half *bx = b->x; float distance = 0.0; - CheckDims(a, b); - /* Auto-vectorized */ for (int i = 0; i < a->dim; i++) - { - float diff = HalfToFloat4(ax[i]) - HalfToFloat4(bx[i]); + distance += HalfToFloat4(ax[i]) * HalfToFloat4(bx[i]); - distance += diff * diff; - } - - PG_RETURN_FLOAT8((double) distance); + return (double) distance; } /* @@ -861,17 +877,10 @@ halfvec_inner_product(PG_FUNCTION_ARGS) { HalfVector *a = PG_GETARG_HALFVEC_P(0); HalfVector *b = PG_GETARG_HALFVEC_P(1); - half *ax = a->x; - half *bx = b->x; - float distance = 0.0; CheckDims(a, b); - /* Auto-vectorized */ - for (int i = 0; i < a->dim; i++) - distance += HalfToFloat4(ax[i]) * HalfToFloat4(bx[i]); - - PG_RETURN_FLOAT8((double) distance); + PG_RETURN_FLOAT8(inner_product_internal(a, b)); } /* @@ -883,17 +892,10 @@ halfvec_negative_inner_product(PG_FUNCTION_ARGS) { HalfVector *a = PG_GETARG_HALFVEC_P(0); HalfVector *b = PG_GETARG_HALFVEC_P(1); - half *ax = a->x; - half *bx = b->x; - float distance = 0.0; CheckDims(a, b); - /* Auto-vectorized */ - for (int i = 0; i < a->dim; i++) - distance += HalfToFloat4(ax[i]) * HalfToFloat4(bx[i]); - - PG_RETURN_FLOAT8((double) distance * -1); + PG_RETURN_FLOAT8(-inner_product_internal(a, b)); } /*