From e146f3cfb6b540714ef9d9cf04b8bdc39b6fe5a6 Mon Sep 17 00:00:00 2001 From: Andrew Kane Date: Sun, 14 Apr 2024 15:11:11 -0700 Subject: [PATCH] Added avg for half vectors [skip ci] --- README.md | 6 ++ sql/vector--0.6.2--0.7.0.sql | 15 ++++ sql/vector.sql | 17 +++++ src/halfvec.c | 112 ++++++++++++++++++++++++++++ src/vector.c | 2 +- test/expected/halfvec_functions.out | 28 +++++++ test/sql/halfvec_functions.sql | 7 ++ 7 files changed, 186 insertions(+), 1 deletion(-) diff --git a/README.md b/README.md index 2b13bc1..3ce6e7f 100644 --- a/README.md +++ b/README.md @@ -902,6 +902,12 @@ l2_distance(halfvec, halfvec) → double precision | Euclidean distance | unrele quantize_binary(halfvec) → bit | quantize | unreleased subvector(halfvec, integer, integer) → halfvec | subvector | unreleased +### Halfvec Aggregate Functions + +Function | Description | Added +--- | --- | --- +avg(halfvec) → halfvec | average | unreleased + ### Bit Type Each bit vector takes `dimensions / 8 + 8` bytes of storage. See the [Postgres docs](https://www.postgresql.org/docs/current/datatype-bit.html) for more info. 
diff --git a/sql/vector--0.6.2--0.7.0.sql b/sql/vector--0.6.2--0.7.0.sql index 25a97d1..5be4006 100644 --- a/sql/vector--0.6.2--0.7.0.sql +++ b/sql/vector--0.6.2--0.7.0.sql @@ -122,6 +122,21 @@ CREATE FUNCTION halfvec_negative_inner_product(halfvec, halfvec) RETURNS float8 CREATE FUNCTION halfvec_spherical_distance(halfvec, halfvec) RETURNS float8 AS 'MODULE_PATHNAME' LANGUAGE C IMMUTABLE STRICT PARALLEL SAFE; +CREATE FUNCTION halfvec_accum(double precision[], halfvec) RETURNS double precision[] + AS 'MODULE_PATHNAME' LANGUAGE C IMMUTABLE STRICT PARALLEL SAFE; + +CREATE FUNCTION halfvec_avg(double precision[]) RETURNS halfvec + AS 'MODULE_PATHNAME' LANGUAGE C IMMUTABLE STRICT PARALLEL SAFE; + +CREATE AGGREGATE avg(halfvec) ( + SFUNC = halfvec_accum, + STYPE = double precision[], + FINALFUNC = halfvec_avg, + COMBINEFUNC = vector_combine, + INITCOND = '{0}', + PARALLEL = SAFE +); + CREATE FUNCTION halfvec(halfvec, integer, boolean) RETURNS halfvec AS 'MODULE_PATHNAME' LANGUAGE C IMMUTABLE STRICT PARALLEL SAFE; diff --git a/sql/vector.sql b/sql/vector.sql index 501eb7c..900a23c 100644 --- a/sql/vector.sql +++ b/sql/vector.sql @@ -417,6 +417,23 @@ CREATE FUNCTION halfvec_negative_inner_product(halfvec, halfvec) RETURNS float8 CREATE FUNCTION halfvec_spherical_distance(halfvec, halfvec) RETURNS float8 AS 'MODULE_PATHNAME' LANGUAGE C IMMUTABLE STRICT PARALLEL SAFE; +CREATE FUNCTION halfvec_accum(double precision[], halfvec) RETURNS double precision[] + AS 'MODULE_PATHNAME' LANGUAGE C IMMUTABLE STRICT PARALLEL SAFE; + +CREATE FUNCTION halfvec_avg(double precision[]) RETURNS halfvec + AS 'MODULE_PATHNAME' LANGUAGE C IMMUTABLE STRICT PARALLEL SAFE; + +-- halfvec aggregates + +CREATE AGGREGATE avg(halfvec) ( + SFUNC = halfvec_accum, + STYPE = double precision[], + FINALFUNC = halfvec_avg, + COMBINEFUNC = vector_combine, + INITCOND = '{0}', + PARALLEL = SAFE +); + -- halfvec cast functions CREATE FUNCTION halfvec(halfvec, integer, boolean) RETURNS halfvec diff --git 
a/src/halfvec.c b/src/halfvec.c
index 725960e..5e12712 100644
--- a/src/halfvec.c
+++ b/src/halfvec.c
@@ -23,6 +23,9 @@
 #define TYPALIGN_INT 'i'
 #endif
 
+#define STATE_DIMS(x) (ARR_DIMS(x)[0] - 1)	/* element 0 is the count */
+#define CreateStateDatums(dim) palloc(sizeof(Datum) * ((dim) + 1))
+
 /*
  * Get a half from a message buffer
  */
@@ -146,6 +149,20 @@ halfvec_isspace(char ch)
 	return false;
 }
 
+/*
+ * Check state array: it must be a one-dimensional, non-empty float8 array
+ * without NULLs. Returns its data; otherwise errors, naming the caller.
+ */
+static float8 *
+CheckStateArray(ArrayType *statearray, const char *caller)
+{
+	if (ARR_NDIM(statearray) != 1 || ARR_DIMS(statearray)[0] < 1 ||
+		ARR_HASNULL(statearray) || ARR_ELEMTYPE(statearray) != FLOAT8OID)
+		elog(ERROR, "%s: expected state array", caller);
+
+	return (float8 *) ARR_DATA_PTR(statearray);
+}
+
 #if PG_VERSION_NUM < 120003
 static pg_noinline void
 float_overflow_error(void)
@@ -1016,3 +1033,98 @@ halfvec_cmp(PG_FUNCTION_ARGS)
 
 	PG_RETURN_INT32(halfvec_cmp_internal(a, b));
 }
+
+/*
+ * Accumulate half vectors: transition function for avg(halfvec). The state
+ * is a float8[] of {count, sum_1 .. sum_dim}; the initial '{0}' carries no
+ * sums, so the first input vector fixes the number of dimensions.
+ */
+PGDLLEXPORT PG_FUNCTION_INFO_V1(halfvec_accum);
+Datum
+halfvec_accum(PG_FUNCTION_ARGS)
+{
+	ArrayType  *statearray = PG_GETARG_ARRAYTYPE_P(0);
+	HalfVector *newval = PG_GETARG_HALFVEC_P(1);
+	float8	   *statevalues;
+	int			dim;
+	bool		newarr;
+	Datum	   *statedatums;
+	half	   *x = newval->x;
+	ArrayType  *result;
+
+	/* Check array before using; dim stays int so a large state cannot wrap */
+	statevalues = CheckStateArray(statearray, "halfvec_accum");
+	dim = STATE_DIMS(statearray);
+	newarr = dim == 0;
+
+	if (newarr)
+		dim = newval->dim;
+	else
+		CheckExpectedDim(dim, newval->dim);
+
+	statedatums = CreateStateDatums(dim);
+	statedatums[0] = Float8GetDatum(statevalues[0] + 1.0);
+
+	if (newarr)
+	{
+		for (int i = 0; i < dim; i++)
+			statedatums[i + 1] = Float8GetDatum((double) HalfToFloat4(x[i]));
+	}
+	else
+	{
+		for (int i = 0; i < dim; i++)
+		{
+			double		v = statevalues[i + 1] + (double) HalfToFloat4(x[i]);
+
+			/* Check for overflow */
+			if (isinf(v))
+				float_overflow_error();
+
+			statedatums[i + 1] = Float8GetDatum(v);
+		}
+	}
+
+	/* Use float8 array like float4_accum */
+	result = construct_array(statedatums, dim + 1, FLOAT8OID,
+							 sizeof(float8), FLOAT8PASSBYVAL, TYPALIGN_DOUBLE);
+
+	pfree(statedatums);
+
+	PG_RETURN_ARRAYTYPE_P(result);
+}
+
+/*
+ * Average half vectors: final function for avg(halfvec). Divides every sum
+ * in the state by the count held in element 0 and returns the result as a
+ * halfvec, or NULL when the count is zero.
+ */
+PGDLLEXPORT PG_FUNCTION_INFO_V1(halfvec_avg);
+Datum
+halfvec_avg(PG_FUNCTION_ARGS)
+{
+	ArrayType  *statearray = PG_GETARG_ARRAYTYPE_P(0);
+	float8	   *statevalues;
+	float8		n;
+	int			dim;
+	HalfVector *result;
+
+	/* Check array before using */
+	statevalues = CheckStateArray(statearray, "halfvec_avg");
+	n = statevalues[0];
+
+	/* SQL defines AVG of no values to be NULL */
+	if (n == 0.0)
+		PG_RETURN_NULL();
+
+	/* Create half vector; int dim lets the check see the untruncated size */
+	dim = STATE_DIMS(statearray);
+	CheckDim(dim);
+	result = InitHalfVector(dim);
+	for (int i = 0; i < dim; i++)
+	{
+		result->x[i] = Float4ToHalf(statevalues[i + 1] / n);
+		CheckElement(result->x[i]);
+	}
+
+	PG_RETURN_POINTER(result);
+}
diff --git a/src/vector.c b/src/vector.c
index b02a4a3..d823e4f 100644
--- a/src/vector.c
+++ b/src/vector.c
@@ -1100,7 +1100,7 @@ vector_accum(PG_FUNCTION_ARGS)
 }
 
 /*
- * Combine vectors
+ * Combine vectors or half vectors
  */
 PGDLLEXPORT PG_FUNCTION_INFO_V1(vector_combine);
 Datum
diff --git a/test/expected/halfvec_functions.out b/test/expected/halfvec_functions.out
index 70d904a..a71c79e 100644
--- a/test/expected/halfvec_functions.out
+++ b/test/expected/halfvec_functions.out
@@ -320,3 +320,31 @@ SELECT subvector('[1,2,3,4,5]'::halfvec, 3, -1);
 ERROR:  halfvec must have at least 1 dimension
 SELECT subvector('[1,2,3,4,5]'::halfvec, -1, 2);
 ERROR:  halfvec must have at least 1 dimension
+SELECT avg(v) FROM unnest(ARRAY['[1,2,3]'::halfvec, '[3,5,7]']) v;
+    avg    
+-----------
+ [2,3.5,5]
+(1 row)
+
+SELECT avg(v) FROM unnest(ARRAY['[1,2,3]'::halfvec, '[3,5,7]', NULL]) v;
+    avg    
+-----------
+ [2,3.5,5]
+(1 row)
+
+SELECT avg(v) FROM unnest(ARRAY[]::halfvec[]) v;
+ avg 
+-----
+ 
+(1 row)
+
+SELECT avg(v) FROM unnest(ARRAY['[1,2]'::halfvec, '[3]']) v;
+ERROR:  expected 2 dimensions, not 1
+SELECT 
avg(v) FROM unnest(ARRAY['[65504]'::halfvec, '[65504]']) v; + avg +--------- + [65504] +(1 row) + +SELECT halfvec_avg(array_agg(n)) FROM generate_series(1, 16002) n; +ERROR: halfvec cannot have more than 16000 dimensions diff --git a/test/sql/halfvec_functions.sql b/test/sql/halfvec_functions.sql index 9f4943b..d5cbcbb 100644 --- a/test/sql/halfvec_functions.sql +++ b/test/sql/halfvec_functions.sql @@ -69,3 +69,10 @@ SELECT subvector('[1,2,3,4,5]'::halfvec, 3, 9); SELECT subvector('[1,2,3,4,5]'::halfvec, 1, 0); SELECT subvector('[1,2,3,4,5]'::halfvec, 3, -1); SELECT subvector('[1,2,3,4,5]'::halfvec, -1, 2); + +SELECT avg(v) FROM unnest(ARRAY['[1,2,3]'::halfvec, '[3,5,7]']) v; +SELECT avg(v) FROM unnest(ARRAY['[1,2,3]'::halfvec, '[3,5,7]', NULL]) v; +SELECT avg(v) FROM unnest(ARRAY[]::halfvec[]) v; +SELECT avg(v) FROM unnest(ARRAY['[1,2]'::halfvec, '[3]']) v; +SELECT avg(v) FROM unnest(ARRAY['[65504]'::halfvec, '[65504]']) v; +SELECT halfvec_avg(array_agg(n)) FROM generate_series(1, 16002) n;