From d49d053e84973f1326a9a9486b0af027f75651ec Mon Sep 17 00:00:00 2001
From: Andrew Kane
Date: Mon, 25 Mar 2024 21:51:03 -0700
Subject: [PATCH] Added subvector function

---
Review note: subvector computes the end of the requested window in 64-bit
arithmetic so that start + count cannot overflow int32, and regression tests
cover the extreme arguments. Blob ids on the "index" lines were not regenerated.

 CHANGELOG.md                 |  4 ++++
 sql/vector--0.6.2--0.7.0.sql |  5 +++++
 sql/vector.sql               |  3 +++
 src/vector.c                 | 47 +++++++++++++++++++++++++++++++++++++++++++++++
 test/expected/functions.out  | 42 ++++++++++++++++++++++++++++++++++++++++++
 test/sql/functions.sql       | 10 ++++++++++
 6 files changed, 111 insertions(+)
 create mode 100644 sql/vector--0.6.2--0.7.0.sql

diff --git a/CHANGELOG.md b/CHANGELOG.md
index 9bcea19..2114f89 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -1,3 +1,7 @@
+## 0.7.0 (unreleased)
+
+- Added `subvector` function
+
 ## 0.6.2 (2024-03-18)
 
 - Reduced lock contention with parallel HNSW index builds
diff --git a/sql/vector--0.6.2--0.7.0.sql b/sql/vector--0.6.2--0.7.0.sql
new file mode 100644
index 0000000..4c964ac
--- /dev/null
+++ b/sql/vector--0.6.2--0.7.0.sql
@@ -0,0 +1,5 @@
+-- complain if script is sourced in psql, rather than via CREATE EXTENSION
+\echo Use "ALTER EXTENSION vector UPDATE TO '0.7.0'" to load this file. \quit
+
+CREATE FUNCTION subvector(vector, int, int) RETURNS vector
+	AS 'MODULE_PATHNAME' LANGUAGE C IMMUTABLE STRICT PARALLEL SAFE;
diff --git a/sql/vector.sql b/sql/vector.sql
index 141e83c..5e94aa7 100644
--- a/sql/vector.sql
+++ b/sql/vector.sql
@@ -58,6 +58,9 @@ CREATE FUNCTION vector_sub(vector, vector) RETURNS vector
 CREATE FUNCTION vector_mul(vector, vector) RETURNS vector
 	AS 'MODULE_PATHNAME' LANGUAGE C IMMUTABLE STRICT PARALLEL SAFE;
 
+CREATE FUNCTION subvector(vector, int, int) RETURNS vector
+	AS 'MODULE_PATHNAME' LANGUAGE C IMMUTABLE STRICT PARALLEL SAFE;
+
 -- private functions
 
 CREATE FUNCTION vector_lt(vector, vector) RETURNS bool
diff --git a/src/vector.c b/src/vector.c
index 5f3cbbb..8236954 100644
--- a/src/vector.c
+++ b/src/vector.c
@@ -860,6 +860,53 @@ vector_mul(PG_FUNCTION_ARGS)
 	PG_RETURN_POINTER(result);
 }
 
+/*
+ * Get a subset of a vector
+ *
+ * Returns up to "count" elements of the vector beginning at position
+ * "start". Positions are 1-based; the requested window is clipped to the
+ * elements the vector actually has, and a window that ends up empty
+ * raises the "at least 1 dimension" error.
+ */
+PGDLLEXPORT PG_FUNCTION_INFO_V1(subvector);
+Datum
+subvector(PG_FUNCTION_ARGS)
+{
+	Vector	   *a = PG_GETARG_VECTOR_P(0);
+	int32		start = PG_GETARG_INT32(1);
+	int32		count = PG_GETARG_INT32(2);
+	float	   *ax = a->x;
+	Vector	   *result;
+	int64		end;
+	int64		len;
+	int			dim;
+
+	/*
+	 * Use 64-bit arithmetic since start + count can exceed the int32 range
+	 * for caller-supplied values (signed overflow is undefined behavior and
+	 * could yield a bogus dimension and an out-of-bounds read)
+	 */
+	end = (int64) start + (int64) count;
+
+	if (start < 1)
+		start = 1;
+
+	if (end > a->dim)
+		end = (int64) a->dim + 1;
+
+	/* Report an empty or inverted window as zero dimensions */
+	len = end - start;
+	dim = len < 1 ? 0 : (int) len;
+
+	CheckDim(dim);
+	result = InitVector(dim);
+
+	for (int i = 0; i < dim; i++)
+		result->x[i] = ax[start - 1 + i];
+
+	PG_RETURN_POINTER(result);
+}
+
 /*
  * Internal helper to compare vectors
  */
diff --git a/test/expected/functions.out b/test/expected/functions.out
index 85d1a2f..1a100d3 100644
--- a/test/expected/functions.out
+++ b/test/expected/functions.out
@@ -208,6 +208,48 @@ SELECT l1_distance('[3e38]', '[-3e38]');
     Infinity
 (1 row)
 
+SELECT subvector('[1,2,3,4,5]', 1, 3);
+ subvector 
+-----------
+ [1,2,3]
+(1 row)
+
+SELECT subvector('[1,2,3,4,5]', 3, 2);
+ subvector 
+-----------
+ [3,4]
+(1 row)
+
+SELECT subvector('[1,2,3,4,5]', -1, 3);
+ subvector 
+-----------
+ [1]
+(1 row)
+
+SELECT subvector('[1,2,3,4,5]', 3, 9);
+ subvector 
+-----------
+ [3,4,5]
+(1 row)
+
+SELECT subvector('[1,2,3,4,5]', 1, 0);
+ERROR:  vector must have at least 1 dimension
+SELECT subvector('[1,2,3,4,5]', -1, 2);
+ERROR:  vector must have at least 1 dimension
+SELECT subvector('[1,2,3,4,5]', 2147483647, 10);
+ERROR:  vector must have at least 1 dimension
+SELECT subvector('[1,2,3,4,5]', 3, 2147483647);
+ subvector 
+-----------
+ [3,4,5]
+(1 row)
+
+SELECT subvector('[1,2,3,4,5]', -2147483644, 2147483647);
+ subvector 
+-----------
+ [1,2]
+(1 row)
+
 SELECT avg(v) FROM unnest(ARRAY['[1,2,3]'::vector, '[3,5,7]']) v;
     avg    
 -----------
diff --git a/test/sql/functions.sql b/test/sql/functions.sql
index 6235684..a2a8729 100644
--- a/test/sql/functions.sql
+++ b/test/sql/functions.sql
@@ -48,6 +48,16 @@ SELECT l1_distance('[0,0]', '[0,1]');
 SELECT l1_distance('[1,2]', '[3]');
 SELECT l1_distance('[3e38]', '[-3e38]');
 
+SELECT subvector('[1,2,3,4,5]', 1, 3);
+SELECT subvector('[1,2,3,4,5]', 3, 2);
+SELECT subvector('[1,2,3,4,5]', -1, 3);
+SELECT subvector('[1,2,3,4,5]', 3, 9);
+SELECT subvector('[1,2,3,4,5]', 1, 0);
+SELECT subvector('[1,2,3,4,5]', -1, 2);
+SELECT subvector('[1,2,3,4,5]', 2147483647, 10);
+SELECT subvector('[1,2,3,4,5]', 3, 2147483647);
+SELECT subvector('[1,2,3,4,5]', -2147483644, 2147483647);
+
 SELECT avg(v) FROM unnest(ARRAY['[1,2,3]'::vector, '[3,5,7]']) v;
 SELECT avg(v) FROM unnest(ARRAY['[1,2,3]'::vector, '[3,5,7]', NULL]) v;
 SELECT avg(v) FROM unnest(ARRAY[]::vector[]) v;