From bc199a33cddb3097ad93fb12309904a2cff64289 Mon Sep 17 00:00:00 2001 From: Andrew Kane Date: Sun, 14 Apr 2024 15:16:17 -0700 Subject: [PATCH] Added sum for half vectors --- README.md | 1 + sql/vector--0.6.2--0.7.0.sql | 7 +++++++ sql/vector.sql | 7 +++++++ test/expected/halfvec_functions.out | 22 ++++++++++++++++++++++ test/sql/halfvec_functions.sql | 6 ++++++ 5 files changed, 43 insertions(+) diff --git a/README.md b/README.md index 3ce6e7f..b36270b 100644 --- a/README.md +++ b/README.md @@ -907,6 +907,7 @@ subvector(halfvec, integer, integer) → halfvec | subvector | unreleased Function | Description | Added --- | --- | --- avg(halfvec) → halfvec | average | unreleased +sum(halfvec) → halfvec | sum | unreleased ### Bit Type diff --git a/sql/vector--0.6.2--0.7.0.sql b/sql/vector--0.6.2--0.7.0.sql index 5be4006..8d13060 100644 --- a/sql/vector--0.6.2--0.7.0.sql +++ b/sql/vector--0.6.2--0.7.0.sql @@ -137,6 +137,13 @@ CREATE AGGREGATE avg(halfvec) ( PARALLEL = SAFE ); +CREATE AGGREGATE sum(halfvec) ( + SFUNC = halfvec_add, + STYPE = halfvec, + COMBINEFUNC = halfvec_add, + PARALLEL = SAFE +); + CREATE FUNCTION halfvec(halfvec, integer, boolean) RETURNS halfvec AS 'MODULE_PATHNAME' LANGUAGE C IMMUTABLE STRICT PARALLEL SAFE; diff --git a/sql/vector.sql b/sql/vector.sql index 900a23c..e183c1f 100644 --- a/sql/vector.sql +++ b/sql/vector.sql @@ -434,6 +434,13 @@ CREATE AGGREGATE avg(halfvec) ( PARALLEL = SAFE ); +CREATE AGGREGATE sum(halfvec) ( + SFUNC = halfvec_add, + STYPE = halfvec, + COMBINEFUNC = halfvec_add, + PARALLEL = SAFE +); + -- halfvec cast functions CREATE FUNCTION halfvec(halfvec, integer, boolean) RETURNS halfvec diff --git a/test/expected/halfvec_functions.out b/test/expected/halfvec_functions.out index a71c79e..8067d46 100644 --- a/test/expected/halfvec_functions.out +++ b/test/expected/halfvec_functions.out @@ -348,3 +348,25 @@ SELECT avg(v) FROM unnest(ARRAY['[65504]'::halfvec, '[65504]']) v; SELECT halfvec_avg(array_agg(n)) FROM generate_series(1, 16002) n; ERROR: halfvec cannot have more than 16000 dimensions +SELECT sum(v) FROM unnest(ARRAY['[1,2,3]'::halfvec, '[3,5,7]']) v; + sum +---------- + [4,7,10] +(1 row) + +SELECT sum(v) FROM unnest(ARRAY['[1,2,3]'::halfvec, '[3,5,7]', NULL]) v; + sum +---------- + [4,7,10] +(1 row) + +SELECT sum(v) FROM unnest(ARRAY[]::halfvec[]) v; + sum +----- + +(1 row) + +SELECT sum(v) FROM unnest(ARRAY['[1,2]'::halfvec, '[3]']) v; +ERROR: different halfvec dimensions 2 and 1 +SELECT sum(v) FROM unnest(ARRAY['[65504]'::halfvec, '[65504]']) v; +ERROR: value out of range: overflow diff --git a/test/sql/halfvec_functions.sql b/test/sql/halfvec_functions.sql index d5cbcbb..430d444 100644 --- a/test/sql/halfvec_functions.sql +++ b/test/sql/halfvec_functions.sql @@ -76,3 +76,9 @@ SELECT avg(v) FROM unnest(ARRAY[]::halfvec[]) v; SELECT avg(v) FROM unnest(ARRAY['[1,2]'::halfvec, '[3]']) v; SELECT avg(v) FROM unnest(ARRAY['[65504]'::halfvec, '[65504]']) v; SELECT halfvec_avg(array_agg(n)) FROM generate_series(1, 16002) n; + +SELECT sum(v) FROM unnest(ARRAY['[1,2,3]'::halfvec, '[3,5,7]']) v; +SELECT sum(v) FROM unnest(ARRAY['[1,2,3]'::halfvec, '[3,5,7]', NULL]) v; +SELECT sum(v) FROM unnest(ARRAY[]::halfvec[]) v; +SELECT sum(v) FROM unnest(ARRAY['[1,2]'::halfvec, '[3]']) v; +SELECT sum(v) FROM unnest(ARRAY['[65504]'::halfvec, '[65504]']) v;