diff --git a/CHANGELOG.md b/CHANGELOG.md
index 09cdc5b..33091a0 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -3,6 +3,7 @@
 - Changed text representation for vector elements to match `real`
 - Changed storage for vector from `plain` to `extended` for new installations
 - Improved accuracy of text parsing for certain inputs
+- Added `avg` aggregate for vector
 - Added experimental support for Windows
 - Dropped support for Postgres 10
 
diff --git a/sql/vector--0.3.2--0.4.0.sql b/sql/vector--0.3.2--0.4.0.sql
index e095e4e..278c7d0 100644
--- a/sql/vector--0.3.2--0.4.0.sql
+++ b/sql/vector--0.3.2--0.4.0.sql
@@ -3,3 +3,21 @@
 
 -- requires Postgres 13+
 -- ALTER TYPE vector SET (STORAGE = extended);
+
+CREATE FUNCTION vector_accum(double precision[], vector) RETURNS double precision[]
+	AS 'MODULE_PATHNAME' LANGUAGE C IMMUTABLE STRICT PARALLEL SAFE;
+
+CREATE FUNCTION vector_avg(double precision[]) RETURNS vector
+	AS 'MODULE_PATHNAME' LANGUAGE C IMMUTABLE STRICT PARALLEL SAFE;
+
+CREATE FUNCTION vector_combine(double precision[], double precision[]) RETURNS double precision[]
+	AS 'MODULE_PATHNAME' LANGUAGE C IMMUTABLE STRICT PARALLEL SAFE;
+
+CREATE AGGREGATE avg(vector) (
+	SFUNC = vector_accum,
+	STYPE = double precision[],
+	FINALFUNC = vector_avg,
+	COMBINEFUNC = vector_combine,
+	INITCOND = '{0}',
+	PARALLEL = SAFE
+);
diff --git a/sql/vector.sql b/sql/vector.sql
index 20f50be..6188e2e 100644
--- a/sql/vector.sql
+++ b/sql/vector.sql
@@ -84,6 +84,26 @@ CREATE FUNCTION vector_negative_inner_product(vector, vector) RETURNS float8
 CREATE FUNCTION vector_spherical_distance(vector, vector) RETURNS float8
 	AS 'MODULE_PATHNAME' LANGUAGE C IMMUTABLE STRICT PARALLEL SAFE;
 
+CREATE FUNCTION vector_accum(double precision[], vector) RETURNS double precision[]
+	AS 'MODULE_PATHNAME' LANGUAGE C IMMUTABLE STRICT PARALLEL SAFE;
+
+CREATE FUNCTION vector_avg(double precision[]) RETURNS vector
+	AS 'MODULE_PATHNAME' LANGUAGE C IMMUTABLE STRICT PARALLEL SAFE;
+
+CREATE FUNCTION vector_combine(double precision[], double precision[]) RETURNS double precision[]
+	AS 'MODULE_PATHNAME' LANGUAGE C IMMUTABLE STRICT PARALLEL SAFE;
+
+-- aggregates
+
+CREATE AGGREGATE avg(vector) (
+	SFUNC = vector_accum,
+	STYPE = double precision[],
+	FINALFUNC = vector_avg,
+	COMBINEFUNC = vector_combine,
+	INITCOND = '{0}',
+	PARALLEL = SAFE
+);
+
 -- cast functions
 
 CREATE FUNCTION vector(vector, integer, boolean) RETURNS vector
diff --git a/src/vector.c b/src/vector.c
index 22a9853..4b1171d 100644
--- a/src/vector.c
+++ b/src/vector.c
@@ -23,6 +23,9 @@
 #define TYPALIGN_INT 'i'
 #endif
 
+#define STATE_DIMS(x) (ARR_DIMS(x)[0] - 1)
+#define CreateStateDatums(dim) palloc(sizeof(Datum) * ((dim) + 1))
+
 PG_MODULE_MAGIC;
 
 /*
@@ -82,6 +85,20 @@ CheckElement(float value)
 				 errmsg("infinite value not allowed in vector")));
 }
 
+/*
+ * Check state array: must be 1-D, non-empty, NULL-free float8; returns its data
+ */
+static float8 *
+CheckStateArray(ArrayType *statearray, const char *caller)
+{
+	if (ARR_NDIM(statearray) != 1 ||
+		ARR_DIMS(statearray)[0] < 1 ||
+		ARR_HASNULL(statearray) ||
+		ARR_ELEMTYPE(statearray) != FLOAT8OID)
+		elog(ERROR, "%s: expected state array", caller);
+	return (float8 *) ARR_DATA_PTR(statearray);
+}
+
 /*
  * Print vector - useful for debugging
  */
@@ -758,3 +775,167 @@ vector_cmp(PG_FUNCTION_ARGS)
 
 	PG_RETURN_INT32(vector_cmp_internal(a, b));
 }
+
+/*
+ * Transition function: add a vector to the state array {n, sum per dimension}
+ */
+PGDLLEXPORT PG_FUNCTION_INFO_V1(vector_accum);
+Datum
+vector_accum(PG_FUNCTION_ARGS)
+{
+	ArrayType *statearray = PG_GETARG_ARRAYTYPE_P(0);
+	Vector *newval = PG_GETARG_VECTOR_P(1);
+	float8 *statevalues;
+	int16 dim;
+	bool newarr;
+	float8 n;
+	Datum *statedatums;
+	float *x = newval->x;
+	ArrayType *result;
+
+	/* Check array before using */
+	statevalues = CheckStateArray(statearray, "vector_accum");
+	dim = STATE_DIMS(statearray);
+	newarr = dim == 0;
+
+	if (newarr)
+		dim = newval->dim;
+	else
+		CheckExpectedDim(dim, newval->dim);
+
+	n = statevalues[0] + 1.0;
+
+	statedatums = CreateStateDatums(dim);
+	statedatums[0] = Float8GetDatumFast(n);
+
+	if (newarr)
+	{
+		for (int i = 0; i < dim; i++)
+			statedatums[i + 1] = Float8GetDatumFast(x[i]);
+	}
+	else
+	{
+		for (int i = 0; i < dim; i++)
+		{
+			double v = statevalues[i + 1] + x[i];
+
+			if (isinf(v))
+				float_overflow_error();
+
+			statedatums[i + 1] = Float8GetDatumFast(v);
+		}
+	}
+
+	/* Use float8 array like float4_accum */
+	result = construct_array(statedatums, dim + 1,
+							 FLOAT8OID,
+							 sizeof(float8), FLOAT8PASSBYVAL, TYPALIGN_DOUBLE);
+
+	pfree(statedatums);
+
+	PG_RETURN_ARRAYTYPE_P(result);
+}
+
+/*
+ * Combine function: merge two partial states; a state with n = 0 contributes nothing
+ */
+PGDLLEXPORT PG_FUNCTION_INFO_V1(vector_combine);
+Datum
+vector_combine(PG_FUNCTION_ARGS)
+{
+	ArrayType *statearray1 = PG_GETARG_ARRAYTYPE_P(0);
+	ArrayType *statearray2 = PG_GETARG_ARRAYTYPE_P(1);
+	float8 *statevalues1;
+	float8 *statevalues2;
+	float8 n;
+	float8 n1;
+	float8 n2;
+	int16 dim;
+	Datum *statedatums;
+	ArrayType *result;
+
+	/* Check arrays before using */
+	statevalues1 = CheckStateArray(statearray1, "vector_combine");
+	statevalues2 = CheckStateArray(statearray2, "vector_combine");
+
+	n1 = statevalues1[0];
+	n2 = statevalues2[0];
+
+	if (n1 == 0.0)
+	{
+		n = n2;
+		dim = STATE_DIMS(statearray2);
+		statedatums = CreateStateDatums(dim);
+		for (int i = 1; i <= dim; i++)
+			statedatums[i] = Float8GetDatumFast(statevalues2[i]);
+	}
+	else if (n2 == 0.0)
+	{
+		n = n1;
+		dim = STATE_DIMS(statearray1);
+		statedatums = CreateStateDatums(dim);
+		for (int i = 1; i <= dim; i++)
+			statedatums[i] = Float8GetDatumFast(statevalues1[i]);
+	}
+	else
+	{
+		n = n1 + n2;
+		dim = STATE_DIMS(statearray1);
+		CheckExpectedDim(dim, STATE_DIMS(statearray2));
+		statedatums = CreateStateDatums(dim);
+		for (int i = 1; i <= dim; i++)
+		{
+			double v = statevalues1[i] + statevalues2[i];
+
+			if (isinf(v))
+				float_overflow_error();
+
+			statedatums[i] = Float8GetDatumFast(v);
+		}
+	}
+
+	statedatums[0] = Float8GetDatumFast(n);
+
+	result = construct_array(statedatums, dim + 1,
+							 FLOAT8OID,
+							 sizeof(float8), FLOAT8PASSBYVAL, TYPALIGN_DOUBLE);
+
+	pfree(statedatums);
+
+	PG_RETURN_ARRAYTYPE_P(result);
+}
+
+/*
+ * Final function: divide each sum by n to build the result; NULL when n = 0
+ */
+PGDLLEXPORT PG_FUNCTION_INFO_V1(vector_avg);
+Datum
+vector_avg(PG_FUNCTION_ARGS)
+{
+	ArrayType *statearray = PG_GETARG_ARRAYTYPE_P(0);
+	float8 *statevalues;
+	float8 n;
+	uint16 dim;
+	Vector *result;
+	float v;
+
+	/* Check array before using */
+	statevalues = CheckStateArray(statearray, "vector_avg");
+	n = statevalues[0];
+
+	/* SQL defines AVG of no values to be NULL */
+	if (n == 0.0)
+		PG_RETURN_NULL();
+
+	/* Create vector */
+	dim = STATE_DIMS(statearray);
+	result = InitVector(dim);
+	for (int i = 0; i < dim; i++)
+	{
+		v = statevalues[i + 1] / n;
+		CheckElement(v);
+		result->x[i] = v;
+	}
+
+	PG_RETURN_POINTER(result);
+}
diff --git a/test/expected/functions.out b/test/expected/functions.out
index a568e32..8f81c74 100644
--- a/test/expected/functions.out
+++ b/test/expected/functions.out
@@ -52,3 +52,23 @@ SELECT cosine_distance('[1,2]', '[0,0]');
 
 SELECT cosine_distance('[1,2]', '[3]');
 ERROR:  different vector dimensions 2 and 1
+SELECT avg(v) FROM unnest(ARRAY['[1,2,3]'::vector, '[3,5,7]']) v;
+    avg    
+-----------
+ [2,3.5,5]
+(1 row)
+
+SELECT avg(v) FROM unnest(ARRAY['[1,2,3]'::vector, '[3,5,7]', NULL]) v;
+    avg    
+-----------
+ [2,3.5,5]
+(1 row)
+
+SELECT avg(v) FROM unnest(ARRAY[]::vector[]) v;
+ avg 
+-----
+ 
+(1 row)
+
+SELECT avg(v) FROM unnest(ARRAY['[1,2]'::vector, '[3]']) v;
+ERROR:  expected 2 dimensions, not 1
diff --git a/test/sql/functions.sql b/test/sql/functions.sql
index 2a1b631..3c6e949 100644
--- a/test/sql/functions.sql
+++ b/test/sql/functions.sql
@@ -13,3 +13,8 @@ SELECT inner_product('[1,2]', '[3]');
 SELECT round(cosine_distance('[1,2]', '[2,4]')::numeric, 5);
 SELECT cosine_distance('[1,2]', '[0,0]');
 SELECT cosine_distance('[1,2]', '[3]');
+
+SELECT avg(v) FROM unnest(ARRAY['[1,2,3]'::vector, '[3,5,7]']) v;
+SELECT avg(v) FROM unnest(ARRAY['[1,2,3]'::vector, '[3,5,7]', NULL]) v;
+SELECT avg(v) FROM unnest(ARRAY[]::vector[]) v;
+SELECT avg(v) FROM unnest(ARRAY['[1,2]'::vector, '[3]']) v;
diff --git a/test/t/008_avg.pl b/test/t/008_avg.pl
new file mode 100644
index 0000000..0cfc273
--- /dev/null
+++ b/test/t/008_avg.pl
@@ -0,0 +1,27 @@
+use strict;
+use warnings;
+use PostgresNode;
+use TestLib;
+use Test::More tests => 4;
+
+# Initialize node
+my $node = get_new_node('node');
+$node->init;
+$node->start;
+
+# Create table
+$node->safe_psql("postgres", "CREATE EXTENSION vector;");
+$node->safe_psql("postgres", "CREATE TABLE tst (v vector(3));");
+$node->safe_psql("postgres",
+	"INSERT INTO tst SELECT ARRAY[1.01 + random(), 2.01 + random(), 3.01 + random()] FROM generate_series(1, 1000000) i;"
+);
+
+# Test avg
+my $avg = $node->safe_psql("postgres", "SELECT AVG(v) FROM tst;");
+like($avg, qr/\[1\.5/, "first element averages to 1.5x");
+like($avg, qr/,2\.5/, "second element averages to 2.5x");
+like($avg, qr/,3\.5/, "third element averages to 3.5x");
+
+# Test explain
+my $explain = $node->safe_psql("postgres", "EXPLAIN SELECT AVG(v) FROM tst;");
+like($explain, qr/Partial Aggregate/, "plan contains a Partial Aggregate node");