Added comparison operators for sparsevec

This commit is contained in:
Andrew Kane
2024-04-14 13:40:37 -07:00
parent c68c2867fd
commit 88788472ba
5 changed files with 383 additions and 2 deletions

View File

@@ -310,6 +310,27 @@ CREATE FUNCTION cosine_distance(sparsevec, sparsevec) RETURNS float8
CREATE FUNCTION sparsevec_norm(sparsevec) RETURNS float8 CREATE FUNCTION sparsevec_norm(sparsevec) RETURNS float8
AS 'MODULE_PATHNAME' LANGUAGE C IMMUTABLE STRICT PARALLEL SAFE; AS 'MODULE_PATHNAME' LANGUAGE C IMMUTABLE STRICT PARALLEL SAFE;
-- Comparison support functions for sparsevec.
-- Each takes two sparsevec arguments and is implemented by the C symbol of
-- the same name in the extension's shared library (MODULE_PATHNAME).
-- The six predicates return bool; sparsevec_cmp returns an int4 ordering.
CREATE FUNCTION sparsevec_lt(sparsevec, sparsevec)
	RETURNS bool
	AS 'MODULE_PATHNAME'
	LANGUAGE C IMMUTABLE STRICT PARALLEL SAFE;

CREATE FUNCTION sparsevec_le(sparsevec, sparsevec)
	RETURNS bool
	AS 'MODULE_PATHNAME'
	LANGUAGE C IMMUTABLE STRICT PARALLEL SAFE;

CREATE FUNCTION sparsevec_eq(sparsevec, sparsevec)
	RETURNS bool
	AS 'MODULE_PATHNAME'
	LANGUAGE C IMMUTABLE STRICT PARALLEL SAFE;

CREATE FUNCTION sparsevec_ne(sparsevec, sparsevec)
	RETURNS bool
	AS 'MODULE_PATHNAME'
	LANGUAGE C IMMUTABLE STRICT PARALLEL SAFE;

CREATE FUNCTION sparsevec_ge(sparsevec, sparsevec)
	RETURNS bool
	AS 'MODULE_PATHNAME'
	LANGUAGE C IMMUTABLE STRICT PARALLEL SAFE;

CREATE FUNCTION sparsevec_gt(sparsevec, sparsevec)
	RETURNS bool
	AS 'MODULE_PATHNAME'
	LANGUAGE C IMMUTABLE STRICT PARALLEL SAFE;

CREATE FUNCTION sparsevec_cmp(sparsevec, sparsevec)
	RETURNS int4
	AS 'MODULE_PATHNAME'
	LANGUAGE C IMMUTABLE STRICT PARALLEL SAFE;
CREATE FUNCTION sparsevec_l2_squared_distance(sparsevec, sparsevec) RETURNS float8 CREATE FUNCTION sparsevec_l2_squared_distance(sparsevec, sparsevec) RETURNS float8
AS 'MODULE_PATHNAME' LANGUAGE C IMMUTABLE STRICT PARALLEL SAFE; AS 'MODULE_PATHNAME' LANGUAGE C IMMUTABLE STRICT PARALLEL SAFE;
@@ -329,7 +350,7 @@ CREATE CAST (sparsevec AS sparsevec)
WITH FUNCTION sparsevec(sparsevec, integer, boolean) AS IMPLICIT; WITH FUNCTION sparsevec(sparsevec, integer, boolean) AS IMPLICIT;
CREATE CAST (sparsevec AS vector) CREATE CAST (sparsevec AS vector)
WITH FUNCTION sparsevec_to_vector(sparsevec, integer, boolean) AS IMPLICIT; WITH FUNCTION sparsevec_to_vector(sparsevec, integer, boolean) AS ASSIGNMENT;
CREATE CAST (vector AS sparsevec) CREATE CAST (vector AS sparsevec)
WITH FUNCTION vector_to_sparsevec(vector, integer, boolean) AS IMPLICIT; WITH FUNCTION vector_to_sparsevec(vector, integer, boolean) AS IMPLICIT;
@@ -349,6 +370,42 @@ CREATE OPERATOR <=> (
COMMUTATOR = '<=>' COMMUTATOR = '<=>'
); );
-- sparsevec < sparsevec, backed by sparsevec_lt.
CREATE OPERATOR < (
	LEFTARG = sparsevec,
	RIGHTARG = sparsevec,
	PROCEDURE = sparsevec_lt,
	COMMUTATOR = > ,
	NEGATOR = >= ,
	RESTRICT = scalarltsel,
	JOIN = scalarltjoinsel
);

-- sparsevec <= sparsevec, backed by sparsevec_le.
CREATE OPERATOR <= (
	LEFTARG = sparsevec,
	RIGHTARG = sparsevec,
	PROCEDURE = sparsevec_le,
	COMMUTATOR = >= ,
	NEGATOR = > ,
	RESTRICT = scalarlesel,
	JOIN = scalarlejoinsel
);

-- sparsevec = sparsevec, backed by sparsevec_eq; it is its own commutator.
CREATE OPERATOR = (
	LEFTARG = sparsevec,
	RIGHTARG = sparsevec,
	PROCEDURE = sparsevec_eq,
	COMMUTATOR = = ,
	NEGATOR = <> ,
	RESTRICT = eqsel,
	JOIN = eqjoinsel
);
-- sparsevec <> sparsevec, backed by sparsevec_ne; it is its own commutator
-- and the negator of =.
-- Selectivity uses the inequality estimators (neqsel / neqjoinsel). The
-- equality estimators would make the planner assume that <> matches only
-- as many rows as = does, which skews row estimates for filters and joins.
CREATE OPERATOR <> (
	LEFTARG = sparsevec, RIGHTARG = sparsevec, PROCEDURE = sparsevec_ne,
	COMMUTATOR = <> , NEGATOR = = ,
	RESTRICT = neqsel, JOIN = neqjoinsel
);
-- sparsevec >= sparsevec, backed by sparsevec_ge.
CREATE OPERATOR >= (
	LEFTARG = sparsevec,
	RIGHTARG = sparsevec,
	PROCEDURE = sparsevec_ge,
	COMMUTATOR = <= ,
	NEGATOR = < ,
	RESTRICT = scalargesel,
	JOIN = scalargejoinsel
);

-- sparsevec > sparsevec, backed by sparsevec_gt.
CREATE OPERATOR > (
	LEFTARG = sparsevec,
	RIGHTARG = sparsevec,
	PROCEDURE = sparsevec_gt,
	COMMUTATOR = < ,
	NEGATOR = <= ,
	RESTRICT = scalargtsel,
	JOIN = scalargtjoinsel
);
CREATE OPERATOR CLASS sparsevec_l2_ops CREATE OPERATOR CLASS sparsevec_l2_ops
FOR TYPE sparsevec USING hnsw AS FOR TYPE sparsevec USING hnsw AS
OPERATOR 1 <-> (sparsevec, sparsevec) FOR ORDER BY float_ops, OPERATOR 1 <-> (sparsevec, sparsevec) FOR ORDER BY float_ops,

View File

@@ -621,6 +621,27 @@ CREATE FUNCTION sparsevec_norm(sparsevec) RETURNS float8
-- sparsevec private functions -- sparsevec private functions
-- Comparison support functions for sparsevec.
-- Each takes two sparsevec arguments and is implemented by the C symbol of
-- the same name in the extension's shared library (MODULE_PATHNAME).
-- The six predicates return bool; sparsevec_cmp returns an int4 ordering.
CREATE FUNCTION sparsevec_lt(sparsevec, sparsevec)
	RETURNS bool
	AS 'MODULE_PATHNAME'
	LANGUAGE C IMMUTABLE STRICT PARALLEL SAFE;

CREATE FUNCTION sparsevec_le(sparsevec, sparsevec)
	RETURNS bool
	AS 'MODULE_PATHNAME'
	LANGUAGE C IMMUTABLE STRICT PARALLEL SAFE;

CREATE FUNCTION sparsevec_eq(sparsevec, sparsevec)
	RETURNS bool
	AS 'MODULE_PATHNAME'
	LANGUAGE C IMMUTABLE STRICT PARALLEL SAFE;

CREATE FUNCTION sparsevec_ne(sparsevec, sparsevec)
	RETURNS bool
	AS 'MODULE_PATHNAME'
	LANGUAGE C IMMUTABLE STRICT PARALLEL SAFE;

CREATE FUNCTION sparsevec_ge(sparsevec, sparsevec)
	RETURNS bool
	AS 'MODULE_PATHNAME'
	LANGUAGE C IMMUTABLE STRICT PARALLEL SAFE;

CREATE FUNCTION sparsevec_gt(sparsevec, sparsevec)
	RETURNS bool
	AS 'MODULE_PATHNAME'
	LANGUAGE C IMMUTABLE STRICT PARALLEL SAFE;

CREATE FUNCTION sparsevec_cmp(sparsevec, sparsevec)
	RETURNS int4
	AS 'MODULE_PATHNAME'
	LANGUAGE C IMMUTABLE STRICT PARALLEL SAFE;
CREATE FUNCTION sparsevec_l2_squared_distance(sparsevec, sparsevec) RETURNS float8 CREATE FUNCTION sparsevec_l2_squared_distance(sparsevec, sparsevec) RETURNS float8
AS 'MODULE_PATHNAME' LANGUAGE C IMMUTABLE STRICT PARALLEL SAFE; AS 'MODULE_PATHNAME' LANGUAGE C IMMUTABLE STRICT PARALLEL SAFE;
@@ -644,7 +665,7 @@ CREATE CAST (sparsevec AS sparsevec)
WITH FUNCTION sparsevec(sparsevec, integer, boolean) AS IMPLICIT; WITH FUNCTION sparsevec(sparsevec, integer, boolean) AS IMPLICIT;
CREATE CAST (sparsevec AS vector) CREATE CAST (sparsevec AS vector)
WITH FUNCTION sparsevec_to_vector(sparsevec, integer, boolean) AS IMPLICIT; WITH FUNCTION sparsevec_to_vector(sparsevec, integer, boolean) AS ASSIGNMENT;
CREATE CAST (vector AS sparsevec) CREATE CAST (vector AS sparsevec)
WITH FUNCTION vector_to_sparsevec(vector, integer, boolean) AS IMPLICIT; WITH FUNCTION vector_to_sparsevec(vector, integer, boolean) AS IMPLICIT;
@@ -666,6 +687,42 @@ CREATE OPERATOR <=> (
COMMUTATOR = '<=>' COMMUTATOR = '<=>'
); );
-- sparsevec < sparsevec, backed by sparsevec_lt.
CREATE OPERATOR < (
	LEFTARG = sparsevec,
	RIGHTARG = sparsevec,
	PROCEDURE = sparsevec_lt,
	COMMUTATOR = > ,
	NEGATOR = >= ,
	RESTRICT = scalarltsel,
	JOIN = scalarltjoinsel
);

-- sparsevec <= sparsevec, backed by sparsevec_le.
CREATE OPERATOR <= (
	LEFTARG = sparsevec,
	RIGHTARG = sparsevec,
	PROCEDURE = sparsevec_le,
	COMMUTATOR = >= ,
	NEGATOR = > ,
	RESTRICT = scalarlesel,
	JOIN = scalarlejoinsel
);

-- sparsevec = sparsevec, backed by sparsevec_eq; it is its own commutator.
CREATE OPERATOR = (
	LEFTARG = sparsevec,
	RIGHTARG = sparsevec,
	PROCEDURE = sparsevec_eq,
	COMMUTATOR = = ,
	NEGATOR = <> ,
	RESTRICT = eqsel,
	JOIN = eqjoinsel
);
-- sparsevec <> sparsevec, backed by sparsevec_ne; it is its own commutator
-- and the negator of =.
-- Selectivity uses the inequality estimators (neqsel / neqjoinsel). The
-- equality estimators would make the planner assume that <> matches only
-- as many rows as = does, which skews row estimates for filters and joins.
CREATE OPERATOR <> (
	LEFTARG = sparsevec, RIGHTARG = sparsevec, PROCEDURE = sparsevec_ne,
	COMMUTATOR = <> , NEGATOR = = ,
	RESTRICT = neqsel, JOIN = neqjoinsel
);
-- sparsevec >= sparsevec, backed by sparsevec_ge.
CREATE OPERATOR >= (
	LEFTARG = sparsevec,
	RIGHTARG = sparsevec,
	PROCEDURE = sparsevec_ge,
	COMMUTATOR = <= ,
	NEGATOR = < ,
	RESTRICT = scalargesel,
	JOIN = scalargejoinsel
);

-- sparsevec > sparsevec, backed by sparsevec_gt.
CREATE OPERATOR > (
	LEFTARG = sparsevec,
	RIGHTARG = sparsevec,
	PROCEDURE = sparsevec_gt,
	COMMUTATOR = < ,
	NEGATOR = <= ,
	RESTRICT = scalargtsel,
	JOIN = scalargtjoinsel
);
-- sparsevec opclasses -- sparsevec opclasses
CREATE OPERATOR CLASS sparsevec_l2_ops CREATE OPERATOR CLASS sparsevec_l2_ops

View File

@@ -798,3 +798,135 @@ sparsevec_norm(PG_FUNCTION_ARGS)
PG_RETURN_FLOAT8(sqrt(norm)); PG_RETURN_FLOAT8(sqrt(norm));
} }
/*
 * Internal helper to compare sparse vectors
 *
 * Returns -1 if a sorts before b, 1 if a sorts after b, and 0 if they
 * compare equal.
 *
 * The stored entries of both vectors are walked in parallel. When the
 * indices at a position differ, the vector with the smaller index has a
 * stored value where the other has none, so the result is decided by the
 * sign of that value alone. When the indices match, the values are
 * compared directly. Dimensions are only used as the final tie-breaker.
 */
static int
sparsevec_cmp_internal(SparseVector * a, SparseVector * b)
{
	float	   *ax = SPARSEVEC_VALUES(a);
	float	   *bx = SPARSEVEC_VALUES(b);
	int			nnz = Min(a->nnz, b->nnz);

	/* Check values before dimensions to be consistent with Postgres arrays */
	for (int i = 0; i < nnz; i++)
	{
		/* a has a stored value at an index where b has none */
		if (a->indices[i] < b->indices[i])
			return ax[i] < 0 ? -1 : 1;

		/* b has a stored value at an index where a has none */
		if (a->indices[i] > b->indices[i])
			return bx[i] < 0 ? 1 : -1;

		if (ax[i] < bx[i])
			return -1;

		if (ax[i] > bx[i])
			return 1;
	}

	/*
	 * All shared positions matched. The first extra entry of the longer
	 * vector is at position nnz (arrays are 0-based and the loop covered
	 * 0 .. nnz - 1). Reading nnz + 1 would skip that entry and, when the
	 * longer vector has exactly one extra entry, read past the end of its
	 * values array.
	 */
	if (a->nnz < b->nnz)
		return bx[nnz] < 0 ? 1 : -1;

	if (a->nnz > b->nnz)
		return ax[nnz] < 0 ? -1 : 1;

	/* Same stored entries: fall back to comparing dimensions */
	if (a->dim < b->dim)
		return -1;

	if (a->dim > b->dim)
		return 1;

	return 0;
}
/*
 * SQL-callable: true when the first sparse vector sorts strictly before
 * the second, as decided by sparsevec_cmp_internal
 */
PGDLLEXPORT PG_FUNCTION_INFO_V1(sparsevec_lt);
Datum
sparsevec_lt(PG_FUNCTION_ARGS)
{
	SparseVector *lhs = PG_GETARG_SPARSEVEC_P(0);
	SparseVector *rhs = PG_GETARG_SPARSEVEC_P(1);
	int			order = sparsevec_cmp_internal(lhs, rhs);

	PG_RETURN_BOOL(order < 0);
}
/*
 * SQL-callable: true when the first sparse vector sorts before or equal
 * to the second, as decided by sparsevec_cmp_internal
 */
PGDLLEXPORT PG_FUNCTION_INFO_V1(sparsevec_le);
Datum
sparsevec_le(PG_FUNCTION_ARGS)
{
	SparseVector *lhs = PG_GETARG_SPARSEVEC_P(0);
	SparseVector *rhs = PG_GETARG_SPARSEVEC_P(1);
	int			order = sparsevec_cmp_internal(lhs, rhs);

	PG_RETURN_BOOL(order <= 0);
}
/*
 * SQL-callable: true when sparsevec_cmp_internal reports the two sparse
 * vectors as equal
 */
PGDLLEXPORT PG_FUNCTION_INFO_V1(sparsevec_eq);
Datum
sparsevec_eq(PG_FUNCTION_ARGS)
{
	SparseVector *lhs = PG_GETARG_SPARSEVEC_P(0);
	SparseVector *rhs = PG_GETARG_SPARSEVEC_P(1);
	int			order = sparsevec_cmp_internal(lhs, rhs);

	PG_RETURN_BOOL(order == 0);
}
/*
 * SQL-callable: true when sparsevec_cmp_internal reports the two sparse
 * vectors as different
 */
PGDLLEXPORT PG_FUNCTION_INFO_V1(sparsevec_ne);
Datum
sparsevec_ne(PG_FUNCTION_ARGS)
{
	SparseVector *lhs = PG_GETARG_SPARSEVEC_P(0);
	SparseVector *rhs = PG_GETARG_SPARSEVEC_P(1);
	int			order = sparsevec_cmp_internal(lhs, rhs);

	PG_RETURN_BOOL(order != 0);
}
/*
 * SQL-callable: true when the first sparse vector sorts after or equal
 * to the second, as decided by sparsevec_cmp_internal
 */
PGDLLEXPORT PG_FUNCTION_INFO_V1(sparsevec_ge);
Datum
sparsevec_ge(PG_FUNCTION_ARGS)
{
	SparseVector *lhs = PG_GETARG_SPARSEVEC_P(0);
	SparseVector *rhs = PG_GETARG_SPARSEVEC_P(1);
	int			order = sparsevec_cmp_internal(lhs, rhs);

	PG_RETURN_BOOL(order >= 0);
}
/*
 * SQL-callable: true when the first sparse vector sorts strictly after
 * the second, as decided by sparsevec_cmp_internal
 */
PGDLLEXPORT PG_FUNCTION_INFO_V1(sparsevec_gt);
Datum
sparsevec_gt(PG_FUNCTION_ARGS)
{
	SparseVector *lhs = PG_GETARG_SPARSEVEC_P(0);
	SparseVector *rhs = PG_GETARG_SPARSEVEC_P(1);
	int			order = sparsevec_cmp_internal(lhs, rhs);

	PG_RETURN_BOOL(order > 0);
}
/*
 * SQL-callable three-way comparison: returns the result of
 * sparsevec_cmp_internal for the two sparse vector arguments
 */
PGDLLEXPORT PG_FUNCTION_INFO_V1(sparsevec_cmp);
Datum
sparsevec_cmp(PG_FUNCTION_ARGS)
{
	SparseVector *lhs = PG_GETARG_SPARSEVEC_P(0);
	SparseVector *rhs = PG_GETARG_SPARSEVEC_P(1);
	int			order = sparsevec_cmp_internal(lhs, rhs);

	PG_RETURN_INT32(order);
}

View File

@@ -1,3 +1,117 @@
SELECT '{1:1,2:2,3:3}/3'::sparsevec < '{1:1,2:2,3:3}/3';
?column?
----------
f
(1 row)
SELECT '{1:1,2:2,3:3}/3'::sparsevec < '{1:1,2:2}/2';
?column?
----------
f
(1 row)
SELECT '{1:1,2:2,3:3}/3'::sparsevec <= '{1:1,2:2,3:3}/3';
?column?
----------
t
(1 row)
SELECT '{1:1,2:2,3:3}/3'::sparsevec <= '{1:1,2:2}/2';
?column?
----------
f
(1 row)
SELECT '{1:1,2:2,3:3}/3'::sparsevec = '{1:1,2:2,3:3}/3';
?column?
----------
t
(1 row)
SELECT '{1:1,2:2,3:3}/3'::sparsevec = '{1:1,2:2}/2';
?column?
----------
f
(1 row)
SELECT '{1:1,2:2,3:3}/3'::sparsevec != '{1:1,2:2,3:3}/3';
?column?
----------
f
(1 row)
SELECT '{1:1,2:2,3:3}/3'::sparsevec != '{1:1,2:2}/2';
?column?
----------
t
(1 row)
SELECT '{1:1,2:2,3:3}/3'::sparsevec >= '{1:1,2:2,3:3}/3';
?column?
----------
t
(1 row)
SELECT '{1:1,2:2,3:3}/3'::sparsevec >= '{1:1,2:2}/2';
?column?
----------
t
(1 row)
SELECT '{1:1,2:2,3:3}/3'::sparsevec > '{1:1,2:2,3:3}/3';
?column?
----------
f
(1 row)
SELECT '{1:1,2:2,3:3}/3'::sparsevec > '{1:1,2:2}/2';
?column?
----------
t
(1 row)
SELECT sparsevec_cmp('{1:1,2:2,3:3}/3', '{1:1,2:2,3:3}/3');
sparsevec_cmp
---------------
0
(1 row)
SELECT sparsevec_cmp('{1:1,2:2,3:3}/3', '{}/3');
sparsevec_cmp
---------------
1
(1 row)
SELECT sparsevec_cmp('{}/3', '{1:1,2:2,3:3}/3');
sparsevec_cmp
---------------
-1
(1 row)
SELECT sparsevec_cmp('{1:1,2:2}/2', '{1:1,2:2,3:3}/3');
sparsevec_cmp
---------------
-1
(1 row)
SELECT sparsevec_cmp('{1:1,2:2,3:3}/3', '{1:1,2:2}/2');
sparsevec_cmp
---------------
1
(1 row)
SELECT sparsevec_cmp('{1:1,2:2}/2', '{1:2,2:3,3:4}/3');
sparsevec_cmp
---------------
-1
(1 row)
SELECT sparsevec_cmp('{1:2,2:3}/2', '{1:1,2:2,3:3}/3');
sparsevec_cmp
---------------
1
(1 row)
SELECT round(sparsevec_norm('{1:1,2:1}/2')::numeric, 5); SELECT round(sparsevec_norm('{1:1,2:1}/2')::numeric, 5);
round round
--------- ---------

View File

@@ -1,3 +1,24 @@
SELECT '{1:1,2:2,3:3}/3'::sparsevec < '{1:1,2:2,3:3}/3';
SELECT '{1:1,2:2,3:3}/3'::sparsevec < '{1:1,2:2}/2';
SELECT '{1:1,2:2,3:3}/3'::sparsevec <= '{1:1,2:2,3:3}/3';
SELECT '{1:1,2:2,3:3}/3'::sparsevec <= '{1:1,2:2}/2';
SELECT '{1:1,2:2,3:3}/3'::sparsevec = '{1:1,2:2,3:3}/3';
SELECT '{1:1,2:2,3:3}/3'::sparsevec = '{1:1,2:2}/2';
SELECT '{1:1,2:2,3:3}/3'::sparsevec != '{1:1,2:2,3:3}/3';
SELECT '{1:1,2:2,3:3}/3'::sparsevec != '{1:1,2:2}/2';
SELECT '{1:1,2:2,3:3}/3'::sparsevec >= '{1:1,2:2,3:3}/3';
SELECT '{1:1,2:2,3:3}/3'::sparsevec >= '{1:1,2:2}/2';
SELECT '{1:1,2:2,3:3}/3'::sparsevec > '{1:1,2:2,3:3}/3';
SELECT '{1:1,2:2,3:3}/3'::sparsevec > '{1:1,2:2}/2';
SELECT sparsevec_cmp('{1:1,2:2,3:3}/3', '{1:1,2:2,3:3}/3');
SELECT sparsevec_cmp('{1:1,2:2,3:3}/3', '{}/3');
SELECT sparsevec_cmp('{}/3', '{1:1,2:2,3:3}/3');
SELECT sparsevec_cmp('{1:1,2:2}/2', '{1:1,2:2,3:3}/3');
SELECT sparsevec_cmp('{1:1,2:2,3:3}/3', '{1:1,2:2}/2');
SELECT sparsevec_cmp('{1:1,2:2}/2', '{1:2,2:3,3:4}/3');
SELECT sparsevec_cmp('{1:2,2:3}/2', '{1:1,2:2,3:3}/3');
SELECT round(sparsevec_norm('{1:1,2:1}/2')::numeric, 5); SELECT round(sparsevec_norm('{1:1,2:1}/2')::numeric, 5);
SELECT sparsevec_norm('{1:3,2:4}/2'); SELECT sparsevec_norm('{1:3,2:4}/2');
SELECT sparsevec_norm('{2:1}/2'); SELECT sparsevec_norm('{2:1}/2');