From 434f3f5e8826d7c00a019377d9264971727f8ef9 Mon Sep 17 00:00:00 2001 From: Andrew Kane Date: Mon, 8 Apr 2024 16:41:50 -0700 Subject: [PATCH] DRY vector distance functions --- src/vector.c | 70 +++++++++++++++++++++++----------------------------- 1 file changed, 31 insertions(+), 39 deletions(-) diff --git a/src/vector.c b/src/vector.c index 7032b3e..54e6a3c 100644 --- a/src/vector.c +++ b/src/vector.c @@ -567,6 +567,22 @@ halfvec_to_vector(PG_FUNCTION_ARGS) PG_RETURN_POINTER(result); } +static float +VectorL2SquaredDistance(int dim, float *ax, float *bx) +{ + float distance = 0.0; + + /* Auto-vectorized */ + for (int i = 0; i < dim; i++) + { + float diff = ax[i] - bx[i]; + + distance += diff * diff; + } + + return distance; +} + /* * Get the L2 distance between vectors */ @@ -576,21 +592,10 @@ l2_distance(PG_FUNCTION_ARGS) { Vector *a = PG_GETARG_VECTOR_P(0); Vector *b = PG_GETARG_VECTOR_P(1); - float *ax = a->x; - float *bx = b->x; - float distance = 0.0; - float diff; CheckDims(a, b); - /* Auto-vectorized */ - for (int i = 0; i < a->dim; i++) - { - diff = ax[i] - bx[i]; - distance += diff * diff; - } - - PG_RETURN_FLOAT8(sqrt((double) distance)); + PG_RETURN_FLOAT8(sqrt((double) VectorL2SquaredDistance(a->dim, a->x, b->x))); } /* @@ -603,21 +608,22 @@ vector_l2_squared_distance(PG_FUNCTION_ARGS) { Vector *a = PG_GETARG_VECTOR_P(0); Vector *b = PG_GETARG_VECTOR_P(1); - float *ax = a->x; - float *bx = b->x; - float distance = 0.0; - float diff; CheckDims(a, b); - /* Auto-vectorized */ - for (int i = 0; i < a->dim; i++) - { - diff = ax[i] - bx[i]; - distance += diff * diff; - } + PG_RETURN_FLOAT8((double) VectorL2SquaredDistance(a->dim, a->x, b->x)); +} - PG_RETURN_FLOAT8((double) distance); +static float +VectorInnerProduct(int dim, float *ax, float *bx) +{ + float distance = 0.0; + + /* Auto-vectorized */ + for (int i = 0; i < dim; i++) + distance += ax[i] * bx[i]; + + return distance; } /* @@ -629,17 +635,10 @@ inner_product(PG_FUNCTION_ARGS) { Vector *a = PG_GETARG_VECTOR_P(0); Vector *b = PG_GETARG_VECTOR_P(1); - float *ax = a->x; - float *bx = b->x; - float distance = 0.0; CheckDims(a, b); - /* Auto-vectorized */ - for (int i = 0; i < a->dim; i++) - distance += ax[i] * bx[i]; - - PG_RETURN_FLOAT8((double) distance); + PG_RETURN_FLOAT8((double) VectorInnerProduct(a->dim, a->x, b->x)); } /* @@ -651,17 +650,10 @@ vector_negative_inner_product(PG_FUNCTION_ARGS) { Vector *a = PG_GETARG_VECTOR_P(0); Vector *b = PG_GETARG_VECTOR_P(1); - float *ax = a->x; - float *bx = b->x; - float distance = 0.0; CheckDims(a, b); - /* Auto-vectorized */ - for (int i = 0; i < a->dim; i++) - distance += ax[i] * bx[i]; - - PG_RETURN_FLOAT8((double) distance * -1); + PG_RETURN_FLOAT8((double) -VectorInnerProduct(a->dim, a->x, b->x)); } /*