diff --git a/sql/vector.sql b/sql/vector.sql index 198f6e7..f3d17aa 100644 --- a/sql/vector.sql +++ b/sql/vector.sql @@ -697,6 +697,27 @@ CREATE FUNCTION l2_norm(intvec) RETURNS float8 -- intvec private functions +CREATE FUNCTION intvec_lt(intvec, intvec) RETURNS bool + AS 'MODULE_PATHNAME' LANGUAGE C IMMUTABLE STRICT PARALLEL SAFE; + +CREATE FUNCTION intvec_le(intvec, intvec) RETURNS bool + AS 'MODULE_PATHNAME' LANGUAGE C IMMUTABLE STRICT PARALLEL SAFE; + +CREATE FUNCTION intvec_eq(intvec, intvec) RETURNS bool + AS 'MODULE_PATHNAME' LANGUAGE C IMMUTABLE STRICT PARALLEL SAFE; + +CREATE FUNCTION intvec_ne(intvec, intvec) RETURNS bool + AS 'MODULE_PATHNAME' LANGUAGE C IMMUTABLE STRICT PARALLEL SAFE; + +CREATE FUNCTION intvec_ge(intvec, intvec) RETURNS bool + AS 'MODULE_PATHNAME' LANGUAGE C IMMUTABLE STRICT PARALLEL SAFE; + +CREATE FUNCTION intvec_gt(intvec, intvec) RETURNS bool + AS 'MODULE_PATHNAME' LANGUAGE C IMMUTABLE STRICT PARALLEL SAFE; + +CREATE FUNCTION intvec_cmp(intvec, intvec) RETURNS int4 + AS 'MODULE_PATHNAME' LANGUAGE C IMMUTABLE STRICT PARALLEL SAFE; + CREATE FUNCTION intvec_l2_squared_distance(intvec, intvec) RETURNS float8 AS 'MODULE_PATHNAME' LANGUAGE C IMMUTABLE STRICT PARALLEL SAFE; @@ -741,8 +762,53 @@ CREATE OPERATOR <+> ( COMMUTATOR = '<+>' ); +CREATE OPERATOR < ( + LEFTARG = intvec, RIGHTARG = intvec, PROCEDURE = intvec_lt, + COMMUTATOR = > , NEGATOR = >= , + RESTRICT = scalarltsel, JOIN = scalarltjoinsel +); + +CREATE OPERATOR <= ( + LEFTARG = intvec, RIGHTARG = intvec, PROCEDURE = intvec_le, + COMMUTATOR = >= , NEGATOR = > , + RESTRICT = scalarlesel, JOIN = scalarlejoinsel +); + +CREATE OPERATOR = ( + LEFTARG = intvec, RIGHTARG = intvec, PROCEDURE = intvec_eq, + COMMUTATOR = = , NEGATOR = <> , + RESTRICT = eqsel, JOIN = eqjoinsel +); + +CREATE OPERATOR <> ( + LEFTARG = intvec, RIGHTARG = intvec, PROCEDURE = intvec_ne, + COMMUTATOR = <> , NEGATOR = = , + RESTRICT = eqsel, JOIN = eqjoinsel +); + +CREATE OPERATOR >= ( + LEFTARG = intvec, RIGHTARG = intvec, PROCEDURE = intvec_ge, + COMMUTATOR = <= , NEGATOR = < , + RESTRICT = scalargesel, JOIN = scalargejoinsel +); + +CREATE OPERATOR > ( + LEFTARG = intvec, RIGHTARG = intvec, PROCEDURE = intvec_gt, + COMMUTATOR = < , NEGATOR = <= , + RESTRICT = scalargtsel, JOIN = scalargtjoinsel +); + -- intvec opclasses +CREATE OPERATOR CLASS intvec_ops + DEFAULT FOR TYPE intvec USING btree AS + OPERATOR 1 < , + OPERATOR 2 <= , + OPERATOR 3 = , + OPERATOR 4 >= , + OPERATOR 5 > , + FUNCTION 1 intvec_cmp(intvec, intvec); + CREATE OPERATOR CLASS intvec_l2_ops FOR TYPE intvec USING hnsw AS OPERATOR 1 <-> (intvec, intvec) FOR ORDER BY float_ops, diff --git a/src/intvec.c b/src/intvec.c index 79bfed2..68209a8 100644 --- a/src/intvec.c +++ b/src/intvec.c @@ -597,3 +597,121 @@ intvec_l2_norm(PG_FUNCTION_ARGS) PG_RETURN_FLOAT8(sqrt((double) norm)); } + +/* + * Internal helper to compare int vectors + */ +static int +intvec_cmp_internal(IntVector * a, IntVector * b) +{ + int dim = Min(a->dim, b->dim); + + /* Check values before dimensions to be consistent with Postgres arrays */ + for (int i = 0; i < dim; i++) + { + if (a->x[i] < b->x[i]) + return -1; + + if (a->x[i] > b->x[i]) + return 1; + } + + if (a->dim < b->dim) + return -1; + + if (a->dim > b->dim) + return 1; + + return 0; +} + +/* + * Less than + */ +FUNCTION_PREFIX PG_FUNCTION_INFO_V1(intvec_lt); +Datum +intvec_lt(PG_FUNCTION_ARGS) +{ + IntVector *a = PG_GETARG_INTVEC_P(0); + IntVector *b = PG_GETARG_INTVEC_P(1); + + PG_RETURN_BOOL(intvec_cmp_internal(a, b) < 0); +} + +/* + * Less than or equal + */ +FUNCTION_PREFIX PG_FUNCTION_INFO_V1(intvec_le); +Datum +intvec_le(PG_FUNCTION_ARGS) +{ + IntVector *a = PG_GETARG_INTVEC_P(0); + IntVector *b = PG_GETARG_INTVEC_P(1); + + PG_RETURN_BOOL(intvec_cmp_internal(a, b) <= 0); +} + +/* + * Equal + */ +FUNCTION_PREFIX PG_FUNCTION_INFO_V1(intvec_eq); +Datum +intvec_eq(PG_FUNCTION_ARGS) +{ + IntVector *a = PG_GETARG_INTVEC_P(0); + IntVector *b = PG_GETARG_INTVEC_P(1); + + PG_RETURN_BOOL(intvec_cmp_internal(a, b) == 0); +} + +/* + * Not equal + */ +FUNCTION_PREFIX PG_FUNCTION_INFO_V1(intvec_ne); +Datum +intvec_ne(PG_FUNCTION_ARGS) +{ + IntVector *a = PG_GETARG_INTVEC_P(0); + IntVector *b = PG_GETARG_INTVEC_P(1); + + PG_RETURN_BOOL(intvec_cmp_internal(a, b) != 0); +} + +/* + * Greater than or equal + */ +FUNCTION_PREFIX PG_FUNCTION_INFO_V1(intvec_ge); +Datum +intvec_ge(PG_FUNCTION_ARGS) +{ + IntVector *a = PG_GETARG_INTVEC_P(0); + IntVector *b = PG_GETARG_INTVEC_P(1); + + PG_RETURN_BOOL(intvec_cmp_internal(a, b) >= 0); +} + +/* + * Greater than + */ +FUNCTION_PREFIX PG_FUNCTION_INFO_V1(intvec_gt); +Datum +intvec_gt(PG_FUNCTION_ARGS) +{ + IntVector *a = PG_GETARG_INTVEC_P(0); + IntVector *b = PG_GETARG_INTVEC_P(1); + + PG_RETURN_BOOL(intvec_cmp_internal(a, b) > 0); +} + +/* + * Compare int vectors + */ +FUNCTION_PREFIX PG_FUNCTION_INFO_V1(intvec_cmp); +Datum +intvec_cmp(PG_FUNCTION_ARGS) +{ + IntVector *a = PG_GETARG_INTVEC_P(0); + IntVector *b = PG_GETARG_INTVEC_P(1); + + PG_RETURN_INT32(intvec_cmp_internal(a, b)); +} diff --git a/test/expected/copy.out b/test/expected/copy.out index b8ee75c..d304f5d 100644 --- a/test/expected/copy.out +++ b/test/expected/copy.out @@ -39,10 +39,14 @@ CREATE TABLE t2 (val intvec(3)); \copy t TO 'results/intvec.bin' WITH (FORMAT binary) \copy t2 FROM 'results/intvec.bin' WITH (FORMAT binary) SELECT * FROM t2 ORDER BY val; -ERROR: could not identify an ordering operator for type intvec -LINE 1: SELECT * FROM t2 ORDER BY val; - ^ -HINT: Use an explicit ordering operator or modify the query. + val +--------- + [0,0,0] + [1,1,1] + [1,2,3] + +(4 rows) + DROP TABLE t; DROP TABLE t2; -- sparsevec diff --git a/test/expected/intvec.out b/test/expected/intvec.out index 32a9326..2e84d68 100644 --- a/test/expected/intvec.out +++ b/test/expected/intvec.out @@ -114,6 +114,120 @@ SELECT unnest('{"[1,2,3]", "[4,5,6]"}'::intvec[]); SELECT '{"[1,2,3]"}'::intvec(2)[]; ERROR: expected 2 dimensions, not 3 +SELECT '[1,2,3]'::intvec < '[1,2,3]'; + ?column? +---------- + f +(1 row) + +SELECT '[1,2,3]'::intvec < '[1,2]'; + ?column? +---------- + f +(1 row) + +SELECT '[1,2,3]'::intvec <= '[1,2,3]'; + ?column? +---------- + t +(1 row) + +SELECT '[1,2,3]'::intvec <= '[1,2]'; + ?column? +---------- + f +(1 row) + +SELECT '[1,2,3]'::intvec = '[1,2,3]'; + ?column? +---------- + t +(1 row) + +SELECT '[1,2,3]'::intvec = '[1,2]'; + ?column? +---------- + f +(1 row) + +SELECT '[1,2,3]'::intvec != '[1,2,3]'; + ?column? +---------- + f +(1 row) + +SELECT '[1,2,3]'::intvec != '[1,2]'; + ?column? +---------- + t +(1 row) + +SELECT '[1,2,3]'::intvec >= '[1,2,3]'; + ?column? +---------- + t +(1 row) + +SELECT '[1,2,3]'::intvec >= '[1,2]'; + ?column? +---------- + t +(1 row) + +SELECT '[1,2,3]'::intvec > '[1,2,3]'; + ?column? +---------- + f +(1 row) + +SELECT '[1,2,3]'::intvec > '[1,2]'; + ?column? +---------- + t +(1 row) + +SELECT intvec_cmp('[1,2,3]', '[1,2,3]'); + intvec_cmp +------------ + 0 +(1 row) + +SELECT intvec_cmp('[1,2,3]', '[0,0,0]'); + intvec_cmp +------------ + 1 +(1 row) + +SELECT intvec_cmp('[0,0,0]', '[1,2,3]'); + intvec_cmp +------------ + -1 +(1 row) + +SELECT intvec_cmp('[1,2]', '[1,2,3]'); + intvec_cmp +------------ + -1 +(1 row) + +SELECT intvec_cmp('[1,2,3]', '[1,2]'); + intvec_cmp +------------ + 1 +(1 row) + +SELECT intvec_cmp('[1,2]', '[2,3,4]'); + intvec_cmp +------------ + -1 +(1 row) + +SELECT intvec_cmp('[2,3]', '[1,2,3]'); + intvec_cmp +------------ + 1 +(1 row) + SELECT l2_distance('[0,0]'::intvec, '[3,4]'); l2_distance ------------- diff --git a/test/sql/intvec.sql b/test/sql/intvec.sql index 0a76d0b..bbc67db 100644 --- a/test/sql/intvec.sql +++ b/test/sql/intvec.sql @@ -27,6 +27,27 @@ SELECT '[1,2,3]'::intvec(16001); SELECT unnest('{"[1,2,3]", "[4,5,6]"}'::intvec[]); SELECT '{"[1,2,3]"}'::intvec(2)[]; +SELECT '[1,2,3]'::intvec < '[1,2,3]'; +SELECT '[1,2,3]'::intvec < '[1,2]'; +SELECT '[1,2,3]'::intvec <= '[1,2,3]'; +SELECT '[1,2,3]'::intvec <= '[1,2]'; +SELECT '[1,2,3]'::intvec = '[1,2,3]'; +SELECT '[1,2,3]'::intvec = '[1,2]'; +SELECT '[1,2,3]'::intvec != '[1,2,3]'; +SELECT '[1,2,3]'::intvec != '[1,2]'; +SELECT '[1,2,3]'::intvec >= '[1,2,3]'; +SELECT '[1,2,3]'::intvec >= '[1,2]'; +SELECT '[1,2,3]'::intvec > '[1,2,3]'; +SELECT '[1,2,3]'::intvec > '[1,2]'; + +SELECT intvec_cmp('[1,2,3]', '[1,2,3]'); +SELECT intvec_cmp('[1,2,3]', '[0,0,0]'); +SELECT intvec_cmp('[0,0,0]', '[1,2,3]'); +SELECT intvec_cmp('[1,2]', '[1,2,3]'); +SELECT intvec_cmp('[1,2,3]', '[1,2]'); +SELECT intvec_cmp('[1,2]', '[2,3,4]'); +SELECT intvec_cmp('[2,3]', '[1,2,3]'); + SELECT l2_distance('[0,0]'::intvec, '[3,4]'); SELECT l2_distance('[0,0]'::intvec, '[0,1]'); SELECT l2_distance('[1,2]'::intvec, '[3]');