Added L2 distance [skip ci]

This commit is contained in:
Andrew Kane
2024-09-23 16:52:07 -07:00
parent 035a31ac91
commit 274e6544d4
5 changed files with 95 additions and 0 deletions

View File

@@ -27,6 +27,14 @@ CREATE TYPE minivec (
STORAGE = external
);
-- L2 distance between two minivec values, returned as float8.
-- Bound to the C symbol minivec_l2_distance in the extension's shared library.
-- STRICT: the function is not called (result is NULL) when either argument is NULL.
CREATE FUNCTION l2_distance(minivec, minivec) RETURNS float8
AS 'MODULE_PATHNAME', 'minivec_l2_distance' LANGUAGE C IMMUTABLE STRICT PARALLEL SAFE;
-- Binary operator <-> on (minivec, minivec), evaluated by l2_distance.
-- The operator is declared as its own commutator, i.e. a <-> b and b <-> a
-- are treated as interchangeable by the planner.
CREATE OPERATOR <-> (
LEFTARG = minivec, RIGHTARG = minivec, PROCEDURE = l2_distance,
COMMUTATOR = '<->'
);
-- C-language function taking (integer[], integer, boolean) and returning sparsevec.
-- No link symbol is given, so the C symbol name defaults to the SQL function
-- name (array_to_sparsevec).
-- NOTE(review): the trailing (integer, boolean) pair looks like a cast
-- function's typmod/explicit arguments -- confirm against the CREATE CAST that uses it.
CREATE FUNCTION array_to_sparsevec(integer[], integer, boolean) RETURNS sparsevec
AS 'MODULE_PATHNAME' LANGUAGE C IMMUTABLE STRICT PARALLEL SAFE;

View File

@@ -675,6 +675,18 @@ CREATE TYPE minivec (
STORAGE = external
);
-- minivec functions

-- L2 distance between two minivec values, returned as float8.
-- Bound to the C symbol minivec_l2_distance in the extension's shared library.
-- STRICT: the function is not called (result is NULL) when either argument is NULL.
CREATE FUNCTION l2_distance(minivec, minivec) RETURNS float8
AS 'MODULE_PATHNAME', 'minivec_l2_distance' LANGUAGE C IMMUTABLE STRICT PARALLEL SAFE;
-- minivec operators

-- Binary operator <-> on (minivec, minivec), evaluated by l2_distance.
-- The operator is declared as its own commutator, i.e. a <-> b and b <-> a
-- are treated as interchangeable by the planner.
CREATE OPERATOR <-> (
LEFTARG = minivec, RIGHTARG = minivec, PROCEDURE = l2_distance,
COMMUTATOR = '<->'
);
-- bit functions
CREATE FUNCTION hamming_distance(bit, bit) RETURNS float8

View File

@@ -17,6 +17,18 @@
#include "utils/numeric.h"
#include "vector.h"
/*
 * Verify that two minivecs have the same number of dimensions.
 *
 * Returns normally when a->dim equals b->dim.  Otherwise reports an ERROR
 * with ERRCODE_DATA_EXCEPTION, naming both dimension counts in the message.
 */
static inline void
CheckDims(MiniVector * a, MiniVector * b)
{
	/* Matching dimensions: nothing to report */
	if (a->dim == b->dim)
		return;

	ereport(ERROR,
			(errcode(ERRCODE_DATA_EXCEPTION),
			 errmsg("different minivec dimensions %d and %d", a->dim, b->dim)));
}
/*
* Ensure expected dimensions
*/
@@ -334,3 +346,34 @@ minivec_send(PG_FUNCTION_ARGS)
PG_RETURN_BYTEA_P(pq_endtypsend(&buf));
}
/*
 * Sum of squared element-wise differences between two fp8 arrays.
 *
 * dim is the number of elements read from both ax and bx.  Each element is
 * converted with Fp8ToFloat4 before subtracting, and the running sum is kept
 * in a float.  Returns 0 when dim is not positive (the loop does not run).
 * The caller takes the square root to obtain the L2 distance.
 */
static float
MinivecL2SquaredDistance(int dim, fp8 * ax, fp8 * bx)
{
	float		sum = 0.0;
	int			i;

	/* Auto-vectorized */
	for (i = 0; i < dim; i++)
	{
		float		delta = Fp8ToFloat4(ax[i]) - Fp8ToFloat4(bx[i]);

		sum += delta * delta;
	}

	return sum;
}
/*
 * Get the L2 distance between fp8 vectors
 *
 * SQL-callable entry point taking two minivec arguments.  CheckDims reports
 * an error when their dimensions differ; otherwise the result is the square
 * root of MinivecL2SquaredDistance, returned as float8.
 */
FUNCTION_PREFIX PG_FUNCTION_INFO_V1(minivec_l2_distance);
Datum
minivec_l2_distance(PG_FUNCTION_ARGS)
{
	MiniVector *lhs = PG_GETARG_MINIVEC_P(0);
	MiniVector *rhs = PG_GETARG_MINIVEC_P(1);
	float		squared;

	CheckDims(lhs, rhs);

	/* Squared distance is computed in float, then widened for sqrt */
	squared = MinivecL2SquaredDistance(lhs->dim, lhs->x, rhs->x);

	PG_RETURN_FLOAT8(sqrt((double) squared));
}

View File

@@ -168,3 +168,29 @@ SELECT '{"[1,2,3]"}'::minivec(2)[];
{"[1,2,3]"}
(1 row)
SELECT l2_distance('[0,0]'::minivec, '[3,4]');
l2_distance
-------------
5
(1 row)
SELECT l2_distance('[0,0]'::minivec, '[0,1]');
l2_distance
-------------
1
(1 row)
SELECT l2_distance('[1,2]'::minivec, '[3]');
ERROR: different minivec dimensions 2 and 1
SELECT l2_distance('[1,1,1,1,1,1,1,1,1]'::minivec, '[1,1,1,1,1,1,1,4,5]');
l2_distance
-------------
5
(1 row)
SELECT '[0,0]'::minivec <-> '[3,4]';
?column?
----------
5
(1 row)

View File

@@ -36,3 +36,9 @@ SELECT '[1,2,3]'::minivec(16001);
SELECT unnest('{"[1,2,3]", "[4,5,6]"}'::minivec[]);
SELECT '{"[1,2,3]"}'::minivec(2)[];
SELECT l2_distance('[0,0]'::minivec, '[3,4]');
SELECT l2_distance('[0,0]'::minivec, '[0,1]');
SELECT l2_distance('[1,2]'::minivec, '[3]');
SELECT l2_distance('[1,1,1,1,1,1,1,1,1]'::minivec, '[1,1,1,1,1,1,1,4,5]');
SELECT '[0,0]'::minivec <-> '[3,4]';