diff --git a/sql/vector--0.6.2--0.7.0.sql b/sql/vector--0.6.2--0.7.0.sql
index d5b4e6b..fc94da5 100644
--- a/sql/vector--0.6.2--0.7.0.sql
+++ b/sql/vector--0.6.2--0.7.0.sql
@@ -310,6 +310,27 @@ CREATE FUNCTION cosine_distance(sparsevec, sparsevec) RETURNS float8
 CREATE FUNCTION sparsevec_norm(sparsevec) RETURNS float8
 	AS 'MODULE_PATHNAME' LANGUAGE C IMMUTABLE STRICT PARALLEL SAFE;
 
+CREATE FUNCTION sparsevec_lt(sparsevec, sparsevec) RETURNS bool
+	AS 'MODULE_PATHNAME' LANGUAGE C IMMUTABLE STRICT PARALLEL SAFE;
+
+CREATE FUNCTION sparsevec_le(sparsevec, sparsevec) RETURNS bool
+	AS 'MODULE_PATHNAME' LANGUAGE C IMMUTABLE STRICT PARALLEL SAFE;
+
+CREATE FUNCTION sparsevec_eq(sparsevec, sparsevec) RETURNS bool
+	AS 'MODULE_PATHNAME' LANGUAGE C IMMUTABLE STRICT PARALLEL SAFE;
+
+CREATE FUNCTION sparsevec_ne(sparsevec, sparsevec) RETURNS bool
+	AS 'MODULE_PATHNAME' LANGUAGE C IMMUTABLE STRICT PARALLEL SAFE;
+
+CREATE FUNCTION sparsevec_ge(sparsevec, sparsevec) RETURNS bool
+	AS 'MODULE_PATHNAME' LANGUAGE C IMMUTABLE STRICT PARALLEL SAFE;
+
+CREATE FUNCTION sparsevec_gt(sparsevec, sparsevec) RETURNS bool
+	AS 'MODULE_PATHNAME' LANGUAGE C IMMUTABLE STRICT PARALLEL SAFE;
+
+CREATE FUNCTION sparsevec_cmp(sparsevec, sparsevec) RETURNS int4
+	AS 'MODULE_PATHNAME' LANGUAGE C IMMUTABLE STRICT PARALLEL SAFE;
+
 CREATE FUNCTION sparsevec_l2_squared_distance(sparsevec, sparsevec) RETURNS float8
 	AS 'MODULE_PATHNAME' LANGUAGE C IMMUTABLE STRICT PARALLEL SAFE;
 
@@ -329,7 +350,7 @@ CREATE CAST (sparsevec AS sparsevec)
 	WITH FUNCTION sparsevec(sparsevec, integer, boolean) AS IMPLICIT;
 
 CREATE CAST (sparsevec AS vector)
-	WITH FUNCTION sparsevec_to_vector(sparsevec, integer, boolean) AS IMPLICIT;
+	WITH FUNCTION sparsevec_to_vector(sparsevec, integer, boolean) AS ASSIGNMENT;
 
 CREATE CAST (vector AS sparsevec)
 	WITH FUNCTION vector_to_sparsevec(vector, integer, boolean) AS IMPLICIT;
@@ -349,6 +370,42 @@ CREATE OPERATOR <=> (
 	COMMUTATOR = '<=>'
 );
 
+CREATE OPERATOR < (
+	LEFTARG = sparsevec, RIGHTARG = sparsevec, PROCEDURE = sparsevec_lt,
+	COMMUTATOR = > , NEGATOR = >= ,
+	RESTRICT = scalarltsel, JOIN = scalarltjoinsel
+);
+
+CREATE OPERATOR <= (
+	LEFTARG = sparsevec, RIGHTARG = sparsevec, PROCEDURE = sparsevec_le,
+	COMMUTATOR = >= , NEGATOR = > ,
+	RESTRICT = scalarlesel, JOIN = scalarlejoinsel
+);
+
+CREATE OPERATOR = (
+	LEFTARG = sparsevec, RIGHTARG = sparsevec, PROCEDURE = sparsevec_eq,
+	COMMUTATOR = = , NEGATOR = <> ,
+	RESTRICT = eqsel, JOIN = eqjoinsel
+);
+
+CREATE OPERATOR <> (
+	LEFTARG = sparsevec, RIGHTARG = sparsevec, PROCEDURE = sparsevec_ne,
+	COMMUTATOR = <> , NEGATOR = = ,
+	RESTRICT = neqsel, JOIN = neqjoinsel
+);
+
+CREATE OPERATOR >= (
+	LEFTARG = sparsevec, RIGHTARG = sparsevec, PROCEDURE = sparsevec_ge,
+	COMMUTATOR = <= , NEGATOR = < ,
+	RESTRICT = scalargesel, JOIN = scalargejoinsel
+);
+
+CREATE OPERATOR > (
+	LEFTARG = sparsevec, RIGHTARG = sparsevec, PROCEDURE = sparsevec_gt,
+	COMMUTATOR = < , NEGATOR = <= ,
+	RESTRICT = scalargtsel, JOIN = scalargtjoinsel
+);
+
 CREATE OPERATOR CLASS sparsevec_l2_ops
 	FOR TYPE sparsevec USING hnsw AS
 	OPERATOR 1 <-> (sparsevec, sparsevec) FOR ORDER BY float_ops,
diff --git a/sql/vector.sql b/sql/vector.sql
index 397c839..4568a83 100644
--- a/sql/vector.sql
+++ b/sql/vector.sql
@@ -621,6 +621,27 @@ CREATE FUNCTION sparsevec_norm(sparsevec) RETURNS float8
 
 -- sparsevec private functions
 
+CREATE FUNCTION sparsevec_lt(sparsevec, sparsevec) RETURNS bool
+	AS 'MODULE_PATHNAME' LANGUAGE C IMMUTABLE STRICT PARALLEL SAFE;
+
+CREATE FUNCTION sparsevec_le(sparsevec, sparsevec) RETURNS bool
+	AS 'MODULE_PATHNAME' LANGUAGE C IMMUTABLE STRICT PARALLEL SAFE;
+
+CREATE FUNCTION sparsevec_eq(sparsevec, sparsevec) RETURNS bool
+	AS 'MODULE_PATHNAME' LANGUAGE C IMMUTABLE STRICT PARALLEL SAFE;
+
+CREATE FUNCTION sparsevec_ne(sparsevec, sparsevec) RETURNS bool
+	AS 'MODULE_PATHNAME' LANGUAGE C IMMUTABLE STRICT PARALLEL SAFE;
+
+CREATE FUNCTION sparsevec_ge(sparsevec, sparsevec) RETURNS bool
+	AS 'MODULE_PATHNAME' LANGUAGE C IMMUTABLE STRICT PARALLEL SAFE;
+
+CREATE FUNCTION sparsevec_gt(sparsevec, sparsevec) RETURNS bool
+	AS 'MODULE_PATHNAME' LANGUAGE C IMMUTABLE STRICT PARALLEL SAFE;
+
+CREATE FUNCTION sparsevec_cmp(sparsevec, sparsevec) RETURNS int4
+	AS 'MODULE_PATHNAME' LANGUAGE C IMMUTABLE STRICT PARALLEL SAFE;
+
 CREATE FUNCTION sparsevec_l2_squared_distance(sparsevec, sparsevec) RETURNS float8
 	AS 'MODULE_PATHNAME' LANGUAGE C IMMUTABLE STRICT PARALLEL SAFE;
 
@@ -644,7 +665,7 @@ CREATE CAST (sparsevec AS sparsevec)
 	WITH FUNCTION sparsevec(sparsevec, integer, boolean) AS IMPLICIT;
 
 CREATE CAST (sparsevec AS vector)
-	WITH FUNCTION sparsevec_to_vector(sparsevec, integer, boolean) AS IMPLICIT;
+	WITH FUNCTION sparsevec_to_vector(sparsevec, integer, boolean) AS ASSIGNMENT;
 
 CREATE CAST (vector AS sparsevec)
 	WITH FUNCTION vector_to_sparsevec(vector, integer, boolean) AS IMPLICIT;
@@ -666,6 +687,42 @@ CREATE OPERATOR <=> (
 	COMMUTATOR = '<=>'
 );
 
+CREATE OPERATOR < (
+	LEFTARG = sparsevec, RIGHTARG = sparsevec, PROCEDURE = sparsevec_lt,
+	COMMUTATOR = > , NEGATOR = >= ,
+	RESTRICT = scalarltsel, JOIN = scalarltjoinsel
+);
+
+CREATE OPERATOR <= (
+	LEFTARG = sparsevec, RIGHTARG = sparsevec, PROCEDURE = sparsevec_le,
+	COMMUTATOR = >= , NEGATOR = > ,
+	RESTRICT = scalarlesel, JOIN = scalarlejoinsel
+);
+
+CREATE OPERATOR = (
+	LEFTARG = sparsevec, RIGHTARG = sparsevec, PROCEDURE = sparsevec_eq,
+	COMMUTATOR = = , NEGATOR = <> ,
+	RESTRICT = eqsel, JOIN = eqjoinsel
+);
+
+CREATE OPERATOR <> (
+	LEFTARG = sparsevec, RIGHTARG = sparsevec, PROCEDURE = sparsevec_ne,
+	COMMUTATOR = <> , NEGATOR = = ,
+	RESTRICT = neqsel, JOIN = neqjoinsel
+);
+
+CREATE OPERATOR >= (
+	LEFTARG = sparsevec, RIGHTARG = sparsevec, PROCEDURE = sparsevec_ge,
+	COMMUTATOR = <= , NEGATOR = < ,
+	RESTRICT = scalargesel, JOIN = scalargejoinsel
+);
+
+CREATE OPERATOR > (
+	LEFTARG = sparsevec, RIGHTARG = sparsevec, PROCEDURE = sparsevec_gt,
+	COMMUTATOR = < , NEGATOR = <= ,
+	RESTRICT = scalargtsel, JOIN = scalargtjoinsel
+);
+
 -- sparsevec opclasses
 
 CREATE OPERATOR CLASS sparsevec_l2_ops
diff --git a/src/sparsevec.c b/src/sparsevec.c
index 3a51b4d..af51c95 100644
--- a/src/sparsevec.c
+++ b/src/sparsevec.c
@@ -798,3 +798,147 @@ sparsevec_norm(PG_FUNCTION_ARGS)
 
 	PG_RETURN_FLOAT8(sqrt(norm));
 }
+
+/*
+ * Internal helper to compare sparse vectors
+ *
+ * Orders a and b as if they were dense arrays: element values are compared
+ * first and the number of dimensions only breaks ties. Assumes the indices
+ * of each vector are stored in ascending order (TODO confirm). Returns -1,
+ * 0, or 1.
+ */
+static int
+sparsevec_cmp_internal(SparseVector * a, SparseVector * b)
+{
+	float	   *ax = SPARSEVEC_VALUES(a);
+	float	   *bx = SPARSEVEC_VALUES(b);
+	int			nnz = Min(a->nnz, b->nnz);
+
+	/* Check values before dimensions to be consistent with Postgres arrays */
+	for (int i = 0; i < nnz; i++)
+	{
+		/* a stores an element at a position where b holds zero */
+		if (a->indices[i] < b->indices[i])
+			return ax[i] < 0 ? -1 : 1;
+
+		/* b stores an element at a position where a holds zero */
+		if (a->indices[i] > b->indices[i])
+			return bx[i] < 0 ? 1 : -1;
+
+		if (ax[i] < bx[i])
+			return -1;
+
+		if (ax[i] > bx[i])
+			return 1;
+	}
+
+	/*
+	 * One vector has stored elements left over. The first of them is at
+	 * position nnz, and it only decides the result when its index lies
+	 * within the other vector's dimensions (where the other holds zero).
+	 */
+	if (a->nnz < b->nnz && b->indices[nnz] < a->dim)
+		return bx[nnz] < 0 ? 1 : -1;
+
+	if (a->nnz > b->nnz && a->indices[nnz] < b->dim)
+		return ax[nnz] < 0 ? -1 : 1;
+
+	if (a->dim < b->dim)
+		return -1;
+
+	if (a->dim > b->dim)
+		return 1;
+
+	return 0;
+}
+
+/*
+ * Less than
+ */
+PGDLLEXPORT PG_FUNCTION_INFO_V1(sparsevec_lt);
+Datum
+sparsevec_lt(PG_FUNCTION_ARGS)
+{
+	SparseVector *a = PG_GETARG_SPARSEVEC_P(0);
+	SparseVector *b = PG_GETARG_SPARSEVEC_P(1);
+
+	PG_RETURN_BOOL(sparsevec_cmp_internal(a, b) < 0);
+}
+
+/*
+ * Less than or equal
+ */
+PGDLLEXPORT PG_FUNCTION_INFO_V1(sparsevec_le);
+Datum
+sparsevec_le(PG_FUNCTION_ARGS)
+{
+	SparseVector *a = PG_GETARG_SPARSEVEC_P(0);
+	SparseVector *b = PG_GETARG_SPARSEVEC_P(1);
+
+	PG_RETURN_BOOL(sparsevec_cmp_internal(a, b) <= 0);
+}
+
+/*
+ * Equal
+ */
+PGDLLEXPORT PG_FUNCTION_INFO_V1(sparsevec_eq);
+Datum
+sparsevec_eq(PG_FUNCTION_ARGS)
+{
+	SparseVector *a = PG_GETARG_SPARSEVEC_P(0);
+	SparseVector *b = PG_GETARG_SPARSEVEC_P(1);
+
+	PG_RETURN_BOOL(sparsevec_cmp_internal(a, b) == 0);
+}
+
+/*
+ * Not equal
+ */
+PGDLLEXPORT PG_FUNCTION_INFO_V1(sparsevec_ne);
+Datum
+sparsevec_ne(PG_FUNCTION_ARGS)
+{
+	SparseVector *a = PG_GETARG_SPARSEVEC_P(0);
+	SparseVector *b = PG_GETARG_SPARSEVEC_P(1);
+
+	PG_RETURN_BOOL(sparsevec_cmp_internal(a, b) != 0);
+}
+
+/*
+ * Greater than or equal
+ */
+PGDLLEXPORT PG_FUNCTION_INFO_V1(sparsevec_ge);
+Datum
+sparsevec_ge(PG_FUNCTION_ARGS)
+{
+	SparseVector *a = PG_GETARG_SPARSEVEC_P(0);
+	SparseVector *b = PG_GETARG_SPARSEVEC_P(1);
+
+	PG_RETURN_BOOL(sparsevec_cmp_internal(a, b) >= 0);
+}
+
+/*
+ * Greater than
+ */
+PGDLLEXPORT PG_FUNCTION_INFO_V1(sparsevec_gt);
+Datum
+sparsevec_gt(PG_FUNCTION_ARGS)
+{
+	SparseVector *a = PG_GETARG_SPARSEVEC_P(0);
+	SparseVector *b = PG_GETARG_SPARSEVEC_P(1);
+
+	PG_RETURN_BOOL(sparsevec_cmp_internal(a, b) > 0);
+}
+
+/*
+ * Compare sparse vectors
+ */
+PGDLLEXPORT PG_FUNCTION_INFO_V1(sparsevec_cmp);
+Datum
+sparsevec_cmp(PG_FUNCTION_ARGS)
+{
+	SparseVector *a = PG_GETARG_SPARSEVEC_P(0);
+	SparseVector *b = PG_GETARG_SPARSEVEC_P(1);
+
+	PG_RETURN_INT32(sparsevec_cmp_internal(a, b));
+}
diff --git a/test/expected/sparsevec_functions.out b/test/expected/sparsevec_functions.out
index c917383..9f8a279 100644
--- a/test/expected/sparsevec_functions.out
+++ b/test/expected/sparsevec_functions.out
@@ -1,3 +1,117 @@
+SELECT '{1:1,2:2,3:3}/3'::sparsevec < '{1:1,2:2,3:3}/3';
+ ?column? 
+----------
+ f
+(1 row)
+
+SELECT '{1:1,2:2,3:3}/3'::sparsevec < '{1:1,2:2}/2';
+ ?column? 
+----------
+ f
+(1 row)
+
+SELECT '{1:1,2:2,3:3}/3'::sparsevec <= '{1:1,2:2,3:3}/3';
+ ?column? 
+----------
+ t
+(1 row)
+
+SELECT '{1:1,2:2,3:3}/3'::sparsevec <= '{1:1,2:2}/2';
+ ?column? 
+----------
+ f
+(1 row)
+
+SELECT '{1:1,2:2,3:3}/3'::sparsevec = '{1:1,2:2,3:3}/3';
+ ?column? 
+----------
+ t
+(1 row)
+
+SELECT '{1:1,2:2,3:3}/3'::sparsevec = '{1:1,2:2}/2';
+ ?column? 
+----------
+ f
+(1 row)
+
+SELECT '{1:1,2:2,3:3}/3'::sparsevec != '{1:1,2:2,3:3}/3';
+ ?column? 
+----------
+ f
+(1 row)
+
+SELECT '{1:1,2:2,3:3}/3'::sparsevec != '{1:1,2:2}/2';
+ ?column? 
+----------
+ t
+(1 row)
+
+SELECT '{1:1,2:2,3:3}/3'::sparsevec >= '{1:1,2:2,3:3}/3';
+ ?column? 
+----------
+ t
+(1 row)
+
+SELECT '{1:1,2:2,3:3}/3'::sparsevec >= '{1:1,2:2}/2';
+ ?column? 
+----------
+ t
+(1 row)
+
+SELECT '{1:1,2:2,3:3}/3'::sparsevec > '{1:1,2:2,3:3}/3';
+ ?column? 
+----------
+ f
+(1 row)
+
+SELECT '{1:1,2:2,3:3}/3'::sparsevec > '{1:1,2:2}/2';
+ ?column? 
+----------
+ t
+(1 row)
+
+SELECT sparsevec_cmp('{1:1,2:2,3:3}/3', '{1:1,2:2,3:3}/3');
+ sparsevec_cmp 
+---------------
+             0
+(1 row)
+
+SELECT sparsevec_cmp('{1:1,2:2,3:3}/3', '{}/3');
+ sparsevec_cmp 
+---------------
+             1
+(1 row)
+
+SELECT sparsevec_cmp('{}/3', '{1:1,2:2,3:3}/3');
+ sparsevec_cmp 
+---------------
+            -1
+(1 row)
+
+SELECT sparsevec_cmp('{1:1,2:2}/2', '{1:1,2:2,3:3}/3');
+ sparsevec_cmp 
+---------------
+            -1
+(1 row)
+
+SELECT sparsevec_cmp('{1:1,2:2,3:3}/3', '{1:1,2:2}/2');
+ sparsevec_cmp 
+---------------
+             1
+(1 row)
+
+SELECT sparsevec_cmp('{1:1,2:2}/2', '{1:2,2:3,3:4}/3');
+ sparsevec_cmp 
+---------------
+            -1
+(1 row)
+
+SELECT sparsevec_cmp('{1:2,2:3}/2', '{1:1,2:2,3:3}/3');
+ sparsevec_cmp 
+---------------
+             1
+(1 row)
+
 SELECT round(sparsevec_norm('{1:1,2:1}/2')::numeric, 5);
   round  
 ---------
diff --git a/test/sql/sparsevec_functions.sql b/test/sql/sparsevec_functions.sql
index 436d059..c7515b5 100644
--- a/test/sql/sparsevec_functions.sql
+++ b/test/sql/sparsevec_functions.sql
@@ -1,3 +1,24 @@
+SELECT '{1:1,2:2,3:3}/3'::sparsevec < '{1:1,2:2,3:3}/3';
+SELECT '{1:1,2:2,3:3}/3'::sparsevec < '{1:1,2:2}/2';
+SELECT '{1:1,2:2,3:3}/3'::sparsevec <= '{1:1,2:2,3:3}/3';
+SELECT '{1:1,2:2,3:3}/3'::sparsevec <= '{1:1,2:2}/2';
+SELECT '{1:1,2:2,3:3}/3'::sparsevec = '{1:1,2:2,3:3}/3';
+SELECT '{1:1,2:2,3:3}/3'::sparsevec = '{1:1,2:2}/2';
+SELECT '{1:1,2:2,3:3}/3'::sparsevec != '{1:1,2:2,3:3}/3';
+SELECT '{1:1,2:2,3:3}/3'::sparsevec != '{1:1,2:2}/2';
+SELECT '{1:1,2:2,3:3}/3'::sparsevec >= '{1:1,2:2,3:3}/3';
+SELECT '{1:1,2:2,3:3}/3'::sparsevec >= '{1:1,2:2}/2';
+SELECT '{1:1,2:2,3:3}/3'::sparsevec > '{1:1,2:2,3:3}/3';
+SELECT '{1:1,2:2,3:3}/3'::sparsevec > '{1:1,2:2}/2';
+
+SELECT sparsevec_cmp('{1:1,2:2,3:3}/3', '{1:1,2:2,3:3}/3');
+SELECT sparsevec_cmp('{1:1,2:2,3:3}/3', '{}/3');
+SELECT sparsevec_cmp('{}/3', '{1:1,2:2,3:3}/3');
+SELECT sparsevec_cmp('{1:1,2:2}/2', '{1:1,2:2,3:3}/3');
+SELECT sparsevec_cmp('{1:1,2:2,3:3}/3', '{1:1,2:2}/2');
+SELECT sparsevec_cmp('{1:1,2:2}/2', '{1:2,2:3,3:4}/3');
+SELECT sparsevec_cmp('{1:2,2:3}/2', '{1:1,2:2,3:3}/3');
+
 SELECT round(sparsevec_norm('{1:1,2:1}/2')::numeric, 5);
 SELECT sparsevec_norm('{1:3,2:4}/2');
 SELECT sparsevec_norm('{2:1}/2');