From 08e7209810da8023bd9cdc52e181edc59fc3ccde Mon Sep 17 00:00:00 2001 From: Andrew Kane Date: Sat, 15 Jul 2023 20:19:51 -0700 Subject: [PATCH] Added element-wise multiplication for vectors --- CHANGELOG.md | 1 + README.md | 1 + sql/vector--0.4.4--0.5.0.sql | 8 +++++++ sql/vector.sql | 8 +++++++ src/vector.c | 44 ++++++++++++++++++++++++++++++++++++ test/expected/functions.out | 10 ++++++++ test/sql/functions.sql | 3 +++ 7 files changed, 75 insertions(+) diff --git a/CHANGELOG.md b/CHANGELOG.md index f07f466..bfa60e1 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -2,6 +2,7 @@ - Added support for parallel index builds - Added `l1_distance` function +- Added element-wise multiplication for vectors ## 0.4.4 (2023-06-12) diff --git a/README.md b/README.md index a1d9f6a..377a5c2 100644 --- a/README.md +++ b/README.md @@ -343,6 +343,7 @@ Operator | Description --- | --- \+ | element-wise addition \- | element-wise subtraction +\* | element-wise multiplication [unreleased] <-> | Euclidean distance <#> | negative inner product <=> | cosine distance diff --git a/sql/vector--0.4.4--0.5.0.sql b/sql/vector--0.4.4--0.5.0.sql index 25a38f8..b8d98be 100644 --- a/sql/vector--0.4.4--0.5.0.sql +++ b/sql/vector--0.4.4--0.5.0.sql @@ -3,3 +3,11 @@ CREATE FUNCTION l1_distance(vector, vector) RETURNS float8 AS 'MODULE_PATHNAME' LANGUAGE C IMMUTABLE STRICT PARALLEL SAFE; + +CREATE FUNCTION vector_mul(vector, vector) RETURNS vector + AS 'MODULE_PATHNAME' LANGUAGE C IMMUTABLE STRICT PARALLEL SAFE; + +CREATE OPERATOR * ( + LEFTARG = vector, RIGHTARG = vector, PROCEDURE = vector_mul, + COMMUTATOR = * +); diff --git a/sql/vector.sql b/sql/vector.sql index d4cf04d..9d98d03 100644 --- a/sql/vector.sql +++ b/sql/vector.sql @@ -55,6 +55,9 @@ CREATE FUNCTION vector_add(vector, vector) RETURNS vector CREATE FUNCTION vector_sub(vector, vector) RETURNS vector AS 'MODULE_PATHNAME' LANGUAGE C IMMUTABLE STRICT PARALLEL SAFE; +CREATE FUNCTION vector_mul(vector, vector) RETURNS vector + AS 'MODULE_PATHNAME' LANGUAGE C IMMUTABLE STRICT PARALLEL SAFE; + -- private functions CREATE FUNCTION vector_lt(vector, vector) RETURNS bool @@ -174,6 +177,11 @@ CREATE OPERATOR - ( COMMUTATOR = - ); +CREATE OPERATOR * ( + LEFTARG = vector, RIGHTARG = vector, PROCEDURE = vector_mul, + COMMUTATOR = * +); + CREATE OPERATOR < ( LEFTARG = vector, RIGHTARG = vector, PROCEDURE = vector_lt, COMMUTATOR = > , NEGATOR = >= , diff --git a/src/vector.c b/src/vector.c index 3b383be..2ef05c7 100644 --- a/src/vector.c +++ b/src/vector.c @@ -125,6 +125,14 @@ float_overflow_error(void) (errcode(ERRCODE_NUMERIC_VALUE_OUT_OF_RANGE), errmsg("value out of range: overflow"))); } + +static pg_noinline void +float_underflow_error(void) +{ + ereport(ERROR, + (errcode(ERRCODE_NUMERIC_VALUE_OUT_OF_RANGE), + errmsg("value out of range: underflow"))); +} #endif /* @@ -759,6 +767,42 @@ vector_sub(PG_FUNCTION_ARGS) PG_RETURN_POINTER(result); } +/* + * Multiply vectors + */ +PGDLLEXPORT PG_FUNCTION_INFO_V1(vector_mul); +Datum +vector_mul(PG_FUNCTION_ARGS) +{ + Vector *a = PG_GETARG_VECTOR_P(0); + Vector *b = PG_GETARG_VECTOR_P(1); + float *ax = a->x; + float *bx = b->x; + Vector *result; + float *rx; + + CheckDims(a, b); + + result = InitVector(a->dim); + rx = result->x; + + /* Auto-vectorized */ + for (int i = 0, imax = a->dim; i < imax; i++) + rx[i] = ax[i] * bx[i]; + + /* Check for overflow and underflow */ + for (int i = 0, imax = a->dim; i < imax; i++) + { + if (isinf(rx[i])) + float_overflow_error(); + + if (rx[i] == 0 && !(ax[i] == 0 || bx[i] == 0)) + float_underflow_error(); + } + + PG_RETURN_POINTER(result); +} + /* * Internal helper to compare vectors */ diff --git a/test/expected/functions.out b/test/expected/functions.out index 1a8a471..4008697 100644 --- a/test/expected/functions.out +++ b/test/expected/functions.out @@ -14,6 +14,16 @@ SELECT '[1,2,3]'::vector - '[4,5,6]'; SELECT '[-3e38]'::vector - '[3e38]'; ERROR: value out of range: overflow +SELECT '[1,2,3]'::vector * '[4,5,6]'; + ?column? +----------- + [4,10,18] +(1 row) + +SELECT '[1e37]'::vector * '[1e37]'; +ERROR: value out of range: overflow +SELECT '[1e-37]'::vector * '[1e-37]'; +ERROR: value out of range: underflow SELECT vector_dims('[1,2,3]'); vector_dims ------------- diff --git a/test/sql/functions.sql b/test/sql/functions.sql index 324debb..01a9bcc 100644 --- a/test/sql/functions.sql +++ b/test/sql/functions.sql @@ -2,6 +2,9 @@ SELECT '[1,2,3]'::vector + '[4,5,6]'; SELECT '[3e38]'::vector + '[3e38]'; SELECT '[1,2,3]'::vector - '[4,5,6]'; SELECT '[-3e38]'::vector - '[3e38]'; +SELECT '[1,2,3]'::vector * '[4,5,6]'; +SELECT '[1e37]'::vector * '[1e37]'; +SELECT '[1e-37]'::vector * '[1e-37]'; SELECT vector_dims('[1,2,3]');