DRY halfvec distance functions

This commit is contained in:
Andrew Kane
2024-04-07 19:03:20 -07:00
parent d861a0304e
commit 3bd67fef54

View File

@@ -801,20 +801,15 @@ vector_to_halfvec(PG_FUNCTION_ARGS)
}
/*
* Get the L2 distance between half vectors
* Get the L2 squared distance between half vectors
*/
PGDLLEXPORT PG_FUNCTION_INFO_V1(halfvec_l2_distance);
Datum
halfvec_l2_distance(PG_FUNCTION_ARGS)
static double
l2_distance_squared_internal(HalfVector * a, HalfVector * b)
{
HalfVector *a = PG_GETARG_HALFVEC_P(0);
HalfVector *b = PG_GETARG_HALFVEC_P(1);
half *ax = a->x;
half *bx = b->x;
float distance = 0.0;
CheckDims(a, b);
/* Auto-vectorized */
for (int i = 0; i < a->dim; i++)
{
@@ -823,7 +818,22 @@ halfvec_l2_distance(PG_FUNCTION_ARGS)
distance += diff * diff;
}
PG_RETURN_FLOAT8(sqrt((double) distance));
return (double) distance;
}
/*
* Get the L2 distance between half vectors
*/
PGDLLEXPORT PG_FUNCTION_INFO_V1(halfvec_l2_distance);
Datum
halfvec_l2_distance(PG_FUNCTION_ARGS)
{
HalfVector *a = PG_GETARG_HALFVEC_P(0);
HalfVector *b = PG_GETARG_HALFVEC_P(1);
CheckDims(a, b);
PG_RETURN_FLOAT8(sqrt(l2_distance_squared_internal(a, b)));
}
/*
@@ -835,21 +845,27 @@ halfvec_l2_squared_distance(PG_FUNCTION_ARGS)
{
HalfVector *a = PG_GETARG_HALFVEC_P(0);
HalfVector *b = PG_GETARG_HALFVEC_P(1);
CheckDims(a, b);
PG_RETURN_FLOAT8(l2_distance_squared_internal(a, b));
}
/*
* Get the inner product of two half vectors
*/
static double
inner_product_internal(HalfVector * a, HalfVector * b)
{
half *ax = a->x;
half *bx = b->x;
float distance = 0.0;
CheckDims(a, b);
/* Auto-vectorized */
for (int i = 0; i < a->dim; i++)
{
float diff = HalfToFloat4(ax[i]) - HalfToFloat4(bx[i]);
distance += HalfToFloat4(ax[i]) * HalfToFloat4(bx[i]);
distance += diff * diff;
}
PG_RETURN_FLOAT8((double) distance);
return (double) distance;
}
/*
@@ -861,17 +877,10 @@ halfvec_inner_product(PG_FUNCTION_ARGS)
{
HalfVector *a = PG_GETARG_HALFVEC_P(0);
HalfVector *b = PG_GETARG_HALFVEC_P(1);
half *ax = a->x;
half *bx = b->x;
float distance = 0.0;
CheckDims(a, b);
/* Auto-vectorized */
for (int i = 0; i < a->dim; i++)
distance += HalfToFloat4(ax[i]) * HalfToFloat4(bx[i]);
PG_RETURN_FLOAT8((double) distance);
PG_RETURN_FLOAT8(inner_product_internal(a, b));
}
/*
@@ -883,17 +892,10 @@ halfvec_negative_inner_product(PG_FUNCTION_ARGS)
{
HalfVector *a = PG_GETARG_HALFVEC_P(0);
HalfVector *b = PG_GETARG_HALFVEC_P(1);
half *ax = a->x;
half *bx = b->x;
float distance = 0.0;
CheckDims(a, b);
/* Auto-vectorized */
for (int i = 0; i < a->dim; i++)
distance += HalfToFloat4(ax[i]) * HalfToFloat4(bx[i]);
PG_RETURN_FLOAT8((double) distance * -1);
PG_RETURN_FLOAT8(-inner_product_internal(a, b));
}
/*