diff --git a/sql/vector--0.7.4--0.8.0.sql b/sql/vector--0.7.4--0.8.0.sql
index 931cd19..38058c8 100644
--- a/sql/vector--0.7.4--0.8.0.sql
+++ b/sql/vector--0.7.4--0.8.0.sql
@@ -27,6 +27,14 @@ CREATE TYPE minivec (
 	STORAGE = external
 );
 
+CREATE FUNCTION l2_distance(minivec, minivec) RETURNS float8
+	AS 'MODULE_PATHNAME', 'minivec_l2_distance' LANGUAGE C IMMUTABLE STRICT PARALLEL SAFE;
+
+CREATE OPERATOR <-> (
+	LEFTARG = minivec, RIGHTARG = minivec, PROCEDURE = l2_distance,
+	COMMUTATOR = '<->'
+);
+
 CREATE FUNCTION array_to_sparsevec(integer[], integer, boolean) RETURNS sparsevec
 	AS 'MODULE_PATHNAME' LANGUAGE C IMMUTABLE STRICT PARALLEL SAFE;
 
diff --git a/sql/vector.sql b/sql/vector.sql
index 1781266..a3c54f1 100644
--- a/sql/vector.sql
+++ b/sql/vector.sql
@@ -675,6 +675,18 @@ CREATE TYPE minivec (
 	STORAGE = external
 );
 
+-- minivec functions
+
+CREATE FUNCTION l2_distance(minivec, minivec) RETURNS float8
+	AS 'MODULE_PATHNAME', 'minivec_l2_distance' LANGUAGE C IMMUTABLE STRICT PARALLEL SAFE;
+
+-- minivec operators
+
+CREATE OPERATOR <-> (
+	LEFTARG = minivec, RIGHTARG = minivec, PROCEDURE = l2_distance,
+	COMMUTATOR = '<->'
+);
+
 -- bit functions
 
 CREATE FUNCTION hamming_distance(bit, bit) RETURNS float8
diff --git a/src/minivec.c b/src/minivec.c
index 2ea5a59..cdab86d 100644
--- a/src/minivec.c
+++ b/src/minivec.c
@@ -17,6 +17,18 @@
 #include "utils/numeric.h"
 #include "vector.h"
 
+/*
+ * Ensure same dimensions: raise ERROR (ERRCODE_DATA_EXCEPTION) reporting both values when a->dim != b->dim
+ */
+static inline void
+CheckDims(MiniVector * a, MiniVector * b)
+{
+	if (a->dim != b->dim)
+		ereport(ERROR,
+				(errcode(ERRCODE_DATA_EXCEPTION),
+				 errmsg("different minivec dimensions %d and %d", a->dim, b->dim)));
+}
+
 /*
  * Ensure expected dimensions
  */
@@ -334,3 +346,34 @@ minivec_send(PG_FUNCTION_ARGS)
 
 	PG_RETURN_BYTEA_P(pq_endtypsend(&buf));
 }
+
+static float
+MinivecL2SquaredDistance(int dim, fp8 * ax, fp8 * bx)
+{
+	float		distance = 0.0;	/* running sum of squared differences; returned without taking sqrt */
+
+	/* Auto-vectorized: accumulate the squared difference of each element pair after Fp8ToFloat4() */
+	for (int i = 0; i < dim; i++)
+	{
+		float		diff = Fp8ToFloat4(ax[i]) - Fp8ToFloat4(bx[i]);
+
+		distance += diff * diff;
+	}
+
+	return distance;
+}
+
+/*
+ * Get the L2 distance between fp8 vectors: CheckDims() first, then sqrt of MinivecL2SquaredDistance()
+ */
+FUNCTION_PREFIX PG_FUNCTION_INFO_V1(minivec_l2_distance);
+Datum
+minivec_l2_distance(PG_FUNCTION_ARGS)
+{
+	MiniVector *a = PG_GETARG_MINIVEC_P(0);
+	MiniVector *b = PG_GETARG_MINIVEC_P(1);
+
+	CheckDims(a, b);
+
+	PG_RETURN_FLOAT8(sqrt((double) MinivecL2SquaredDistance(a->dim, a->x, b->x)));
+}
diff --git a/test/expected/minivec.out b/test/expected/minivec.out
index 2d28a38..b4560be 100644
--- a/test/expected/minivec.out
+++ b/test/expected/minivec.out
@@ -168,3 +168,29 @@ SELECT '{"[1,2,3]"}'::minivec(2)[];
  {"[1,2,3]"}
 (1 row)
 
+SELECT l2_distance('[0,0]'::minivec, '[3,4]');
+ l2_distance 
+-------------
+           5
+(1 row)
+
+SELECT l2_distance('[0,0]'::minivec, '[0,1]');
+ l2_distance 
+-------------
+           1
+(1 row)
+
+SELECT l2_distance('[1,2]'::minivec, '[3]');
+ERROR:  different minivec dimensions 2 and 1
+SELECT l2_distance('[1,1,1,1,1,1,1,1,1]'::minivec, '[1,1,1,1,1,1,1,4,5]');
+ l2_distance 
+-------------
+           5
+(1 row)
+
+SELECT '[0,0]'::minivec <-> '[3,4]';
+ ?column? 
+----------
+        5
+(1 row)
+
diff --git a/test/sql/minivec.sql b/test/sql/minivec.sql
index 61e11d1..b6a7f3e 100644
--- a/test/sql/minivec.sql
+++ b/test/sql/minivec.sql
@@ -36,3 +36,9 @@ SELECT '[1,2,3]'::minivec(16001);
 
 SELECT unnest('{"[1,2,3]", "[4,5,6]"}'::minivec[]);
 SELECT '{"[1,2,3]"}'::minivec(2)[];
+
+SELECT l2_distance('[0,0]'::minivec, '[3,4]');
+SELECT l2_distance('[0,0]'::minivec, '[0,1]');
+SELECT l2_distance('[1,2]'::minivec, '[3]');
+SELECT l2_distance('[1,1,1,1,1,1,1,1,1]'::minivec, '[1,1,1,1,1,1,1,4,5]');
+SELECT '[0,0]'::minivec <-> '[3,4]';