diff --git a/CHANGELOG.md b/CHANGELOG.md index bfa60e1..e7b6f37 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -3,6 +3,7 @@ - Added support for parallel index builds - Added `l1_distance` function - Added element-wise multiplication for vectors +- Added `sum` aggregate ## 0.4.4 (2023-06-12) diff --git a/README.md b/README.md index 377a5c2..cd9ef94 100644 --- a/README.md +++ b/README.md @@ -364,6 +364,7 @@ vector_norm(vector) → double precision | Euclidean norm Function | Description --- | --- avg(vector) → vector | arithmetic mean +sum(vector) → vector | sum [unreleased] ## Installation Notes diff --git a/sql/vector--0.4.4--0.5.0.sql b/sql/vector--0.4.4--0.5.0.sql index b8d98be..1521589 100644 --- a/sql/vector--0.4.4--0.5.0.sql +++ b/sql/vector--0.4.4--0.5.0.sql @@ -11,3 +11,15 @@ CREATE OPERATOR * ( LEFTARG = vector, RIGHTARG = vector, PROCEDURE = vector_mul, COMMUTATOR = * ); + +CREATE FUNCTION vector_sum(double precision[]) RETURNS vector + AS 'MODULE_PATHNAME' LANGUAGE C IMMUTABLE STRICT PARALLEL SAFE; + +CREATE AGGREGATE sum(vector) ( + SFUNC = vector_accum, + STYPE = double precision[], + FINALFUNC = vector_sum, + COMBINEFUNC = vector_combine, + INITCOND = '{0}', + PARALLEL = SAFE +); diff --git a/sql/vector.sql b/sql/vector.sql index 9d98d03..55f830b 100644 --- a/sql/vector.sql +++ b/sql/vector.sql @@ -99,6 +99,9 @@ CREATE FUNCTION vector_avg(double precision[]) RETURNS vector CREATE FUNCTION vector_combine(double precision[], double precision[]) RETURNS double precision[] AS 'MODULE_PATHNAME' LANGUAGE C IMMUTABLE STRICT PARALLEL SAFE; +CREATE FUNCTION vector_sum(double precision[]) RETURNS vector + AS 'MODULE_PATHNAME' LANGUAGE C IMMUTABLE STRICT PARALLEL SAFE; + -- aggregates CREATE AGGREGATE avg(vector) ( @@ -110,6 +113,15 @@ CREATE AGGREGATE avg(vector) ( PARALLEL = SAFE ); +CREATE AGGREGATE sum(vector) ( + SFUNC = vector_accum, + STYPE = double precision[], + FINALFUNC = vector_sum, + COMBINEFUNC = vector_combine, + INITCOND = '{0}', + PARALLEL = SAFE +); + -- cast functions CREATE FUNCTION vector(vector, integer, boolean) RETURNS vector diff --git a/src/vector.c b/src/vector.c index 2ef05c7..8656468 100644 --- a/src/vector.c +++ b/src/vector.c @@ -1077,3 +1077,42 @@ vector_avg(PG_FUNCTION_ARGS) PG_RETURN_POINTER(result); } + +/* + * Sum vectors + */ +PGDLLEXPORT PG_FUNCTION_INFO_V1(vector_sum); +Datum +vector_sum(PG_FUNCTION_ARGS) +{ + ArrayType *statearray = PG_GETARG_ARRAYTYPE_P(0); + float8 *statevalues; + float8 n; + uint16 dim; + Vector *result; + + /* Check array before using */ + statevalues = CheckStateArray(statearray, "vector_sum"); + n = statevalues[0]; + + /* SQL defines AVG of no values to be NULL */ + if (n == 0.0) + PG_RETURN_NULL(); + + /* Create vector */ + dim = STATE_DIMS(statearray); + CheckDim(dim); + result = InitVector(dim); + for (int i = 0; i < dim; i++) + { + result->x[i] = statevalues[i + 1]; + + /* Check for overflow */ + if (isinf(result->x[i])) + float_overflow_error(); + + CheckElement(result->x[i]); + } + + PG_RETURN_POINTER(result); +} diff --git a/test/expected/functions.out b/test/expected/functions.out index 4008697..451762b 100644 --- a/test/expected/functions.out +++ b/test/expected/functions.out @@ -132,3 +132,27 @@ SELECT avg(v) FROM unnest(ARRAY['[1,2]'::vector, '[3]']) v; ERROR: expected 2 dimensions, not 1 SELECT vector_avg(array_agg(n)) FROM generate_series(1, 16002) n; ERROR: vector cannot have more than 16000 dimensions +SELECT sum(v) FROM unnest(ARRAY['[1,2,3]'::vector, '[3,5,7]']) v; + sum +---------- + [4,7,10] +(1 row) + +SELECT sum(v) FROM unnest(ARRAY['[1,2,3]'::vector, '[3,5,7]', NULL]) v; + sum +---------- + [4,7,10] +(1 row) + +SELECT sum(v) FROM unnest(ARRAY[]::vector[]) v; + sum +----- + +(1 row) + +SELECT sum(v) FROM unnest(ARRAY['[1,2]'::vector, '[3]']) v; +ERROR: expected 2 dimensions, not 1 +SELECT sum(v) FROM unnest(ARRAY['[3e38]'::vector, '[3e38]']) v; +ERROR: value out of range: overflow +SELECT vector_sum(array_agg(n)) FROM generate_series(1, 16002) n; +ERROR: vector cannot have more than 16000 dimensions diff --git a/test/sql/functions.sql b/test/sql/functions.sql index 01a9bcc..83ee22e 100644 --- a/test/sql/functions.sql +++ b/test/sql/functions.sql @@ -34,3 +34,10 @@ SELECT avg(v) FROM unnest(ARRAY['[1,2,3]'::vector, '[3,5,7]', NULL]) v; SELECT avg(v) FROM unnest(ARRAY[]::vector[]) v; SELECT avg(v) FROM unnest(ARRAY['[1,2]'::vector, '[3]']) v; SELECT vector_avg(array_agg(n)) FROM generate_series(1, 16002) n; + +SELECT sum(v) FROM unnest(ARRAY['[1,2,3]'::vector, '[3,5,7]']) v; +SELECT sum(v) FROM unnest(ARRAY['[1,2,3]'::vector, '[3,5,7]', NULL]) v; +SELECT sum(v) FROM unnest(ARRAY[]::vector[]) v; +SELECT sum(v) FROM unnest(ARRAY['[1,2]'::vector, '[3]']) v; +SELECT sum(v) FROM unnest(ARRAY['[3e38]'::vector, '[3e38]']) v; +SELECT vector_sum(array_agg(n)) FROM generate_series(1, 16002) n;