From 4f6c4850d9d33f58eba38d6c3f1eb4bb276c2109 Mon Sep 17 00:00:00 2001 From: Andrew Kane Date: Sun, 14 Apr 2024 22:59:28 -0700 Subject: [PATCH] Added l1_distance function for sparsevec [skip ci] --- sql/vector--0.6.2--0.7.0.sql | 3 ++ sql/vector.sql | 3 ++ src/sparsevec.c | 49 +++++++++++++++++++++++++++ test/expected/sparsevec_functions.out | 20 +++++++++++ test/sql/sparsevec_functions.sql | 5 +++ 5 files changed, 80 insertions(+) diff --git a/sql/vector--0.6.2--0.7.0.sql b/sql/vector--0.6.2--0.7.0.sql index b2db93e..220d331 100644 --- a/sql/vector--0.6.2--0.7.0.sql +++ b/sql/vector--0.6.2--0.7.0.sql @@ -344,6 +344,9 @@ CREATE FUNCTION inner_product(sparsevec, sparsevec) RETURNS float8 CREATE FUNCTION cosine_distance(sparsevec, sparsevec) RETURNS float8 AS 'MODULE_PATHNAME', 'sparsevec_cosine_distance' LANGUAGE C IMMUTABLE STRICT PARALLEL SAFE; +CREATE FUNCTION l1_distance(sparsevec, sparsevec) RETURNS float8 + AS 'MODULE_PATHNAME', 'sparsevec_l1_distance' LANGUAGE C IMMUTABLE STRICT PARALLEL SAFE; + CREATE FUNCTION l2_norm(sparsevec) RETURNS float8 AS 'MODULE_PATHNAME', 'sparsevec_l2_norm' LANGUAGE C IMMUTABLE STRICT PARALLEL SAFE; diff --git a/sql/vector.sql b/sql/vector.sql index 69c3fc0..efa53b6 100644 --- a/sql/vector.sql +++ b/sql/vector.sql @@ -653,6 +653,9 @@ CREATE FUNCTION inner_product(sparsevec, sparsevec) RETURNS float8 CREATE FUNCTION cosine_distance(sparsevec, sparsevec) RETURNS float8 AS 'MODULE_PATHNAME', 'sparsevec_cosine_distance' LANGUAGE C IMMUTABLE STRICT PARALLEL SAFE; +CREATE FUNCTION l1_distance(sparsevec, sparsevec) RETURNS float8 + AS 'MODULE_PATHNAME', 'sparsevec_l1_distance' LANGUAGE C IMMUTABLE STRICT PARALLEL SAFE; + CREATE FUNCTION l2_norm(sparsevec) RETURNS float8 AS 'MODULE_PATHNAME', 'sparsevec_l2_norm' LANGUAGE C IMMUTABLE STRICT PARALLEL SAFE; diff --git a/src/sparsevec.c b/src/sparsevec.c index 407195d..6ebdfba 100644 --- a/src/sparsevec.c +++ b/src/sparsevec.c @@ -781,6 +781,55 @@ sparsevec_cosine_distance(PG_FUNCTION_ARGS) PG_RETURN_FLOAT8(1.0 - similarity); } +/* + * Get the L1 distance between two sparse vectors + */ +PGDLLEXPORT PG_FUNCTION_INFO_V1(sparsevec_l1_distance); +Datum +sparsevec_l1_distance(PG_FUNCTION_ARGS) +{ + SparseVector *a = PG_GETARG_SPARSEVEC_P(0); + SparseVector *b = PG_GETARG_SPARSEVEC_P(1); + float *ax = SPARSEVEC_VALUES(a); + float *bx = SPARSEVEC_VALUES(b); + double distance = 0.0; + int bpos = 0; + + CheckDims(a, b); + + for (int i = 0; i < a->nnz; i++) + { + int ai = a->indices[i]; + int bi = -1; + + for (int j = bpos; j < b->nnz; j++) + { + bi = b->indices[j]; + + if (ai == bi) + distance += fabs(ax[i] - bx[j]); + else if (ai > bi) + distance += fabs(bx[j]); + + /* Update start for next iteration */ + if (ai >= bi) + bpos = j + 1; + + /* Found or passed it */ + if (bi >= ai) + break; + } + + if (ai != bi) + distance += fabs(ax[i]); + } + + for (int j = bpos; j < b->nnz; j++) + distance += fabs(bx[j]); + + PG_RETURN_FLOAT8(distance); +} + /* * Get the L2 norm of a sparse vector */ diff --git a/test/expected/sparsevec_functions.out b/test/expected/sparsevec_functions.out index 8f3dab6..96205ce 100644 --- a/test/expected/sparsevec_functions.out +++ b/test/expected/sparsevec_functions.out @@ -198,3 +198,23 @@ SELECT cosine_distance('{}/1'::sparsevec, '{}/1'); SELECT cosine_distance('{1:2}/2'::sparsevec, '{1:1}/3'); ERROR: different sparsevec dimensions 2 and 3 +SELECT l1_distance('{}/2'::sparsevec, '{1:3,2:4}/2'); + l1_distance +------------- + 7 +(1 row) + +SELECT l1_distance('{}/2'::sparsevec, '{2:1}/2'); + l1_distance +------------- + 1 +(1 row) + +SELECT l1_distance('{1:1,2:2}/2'::sparsevec, '{1:3}/1'); +ERROR: different sparsevec dimensions 2 and 1 +SELECT l1_distance('{1:3e38}/1'::sparsevec, '{1:-3e38}/1'); + l1_distance +------------- + Infinity +(1 row) + diff --git a/test/sql/sparsevec_functions.sql b/test/sql/sparsevec_functions.sql index 76549c7..a1e6e42 100644 --- a/test/sql/sparsevec_functions.sql +++ b/test/sql/sparsevec_functions.sql @@ -37,3 +37,8 @@ SELECT cosine_distance('{1:1,2:1}/2'::sparsevec, '{1:-1,2:-1}/2'); SELECT cosine_distance('{1:2}/2'::sparsevec, '{2:2}/2'); SELECT cosine_distance('{}/1'::sparsevec, '{}/1'); SELECT cosine_distance('{1:2}/2'::sparsevec, '{1:1}/3'); + +SELECT l1_distance('{}/2'::sparsevec, '{1:3,2:4}/2'); +SELECT l1_distance('{}/2'::sparsevec, '{2:1}/2'); +SELECT l1_distance('{1:1,2:2}/2'::sparsevec, '{1:3}/1'); +SELECT l1_distance('{1:3e38}/1'::sparsevec, '{1:-3e38}/1');