From d6044dd423f552c3591380c6fff8aae2591f93fa Mon Sep 17 00:00:00 2001
From: Andrew Kane
Date: Tue, 2 Apr 2024 12:13:04 -0700
Subject: [PATCH] Added subvector function

The exclusive end of the requested range is computed in 64-bit arithmetic so
that start + count cannot overflow int32 and cause an out-of-bounds read.
---
 CHANGELOG.md                       |  1 +
 README.md                          |  1 +
 sql/vector--0.6.2--0.7.0.sql       |  3 ++
 sql/vector.sql                     |  3 ++
 src/vector.c                       | 43 +++++++++++++++++++++++++++++
 test/expected/vector_functions.out | 44 ++++++++++++++++++++++++++++++
 test/sql/vector_functions.sql      | 11 ++++++++
 7 files changed, 106 insertions(+)

diff --git a/CHANGELOG.md b/CHANGELOG.md
index 6e57bbf..046f082 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -4,6 +4,7 @@
 - Added `hamming_distance` function
 - Added `jaccard_distance` function
 - Added `quantize_binary` function
+- Added `subvector` function
 - Updated comparison operators to support vectors with different dimensions
 
 ## 0.6.2 (2024-03-18)
diff --git a/README.md b/README.md
index 2f02080..140f635 100644
--- a/README.md
+++ b/README.md
@@ -738,6 +738,7 @@ inner_product(vector, vector) → double precision | inner product |
 l2_distance(vector, vector) → double precision | Euclidean distance |
 l1_distance(vector, vector) → double precision | taxicab distance | 0.5.0
 quantize_binary(vector) → bit | quantize | unreleased
+subvector(vector, integer, integer) → vector | subvector | unreleased
 vector_dims(vector) → integer | number of dimensions |
 vector_norm(vector) → double precision | Euclidean norm |
 
diff --git a/sql/vector--0.6.2--0.7.0.sql b/sql/vector--0.6.2--0.7.0.sql
index 68409d3..e3f7c20 100644
--- a/sql/vector--0.6.2--0.7.0.sql
+++ b/sql/vector--0.6.2--0.7.0.sql
@@ -4,6 +4,9 @@
 CREATE FUNCTION quantize_binary(vector) RETURNS bit
 	AS 'MODULE_PATHNAME' LANGUAGE C IMMUTABLE STRICT PARALLEL SAFE;
 
+CREATE FUNCTION subvector(vector, int, int) RETURNS vector
+	AS 'MODULE_PATHNAME' LANGUAGE C IMMUTABLE STRICT PARALLEL SAFE;
+
 CREATE FUNCTION hamming_distance(bit, bit) RETURNS float8
 	AS 'MODULE_PATHNAME' LANGUAGE C IMMUTABLE STRICT PARALLEL SAFE;
 
diff --git a/sql/vector.sql b/sql/vector.sql
index 48b91fd..543ed5a 100644
--- a/sql/vector.sql
+++ b/sql/vector.sql
@@ -61,6 +61,9 @@ CREATE FUNCTION vector_mul(vector, vector) RETURNS vector
 CREATE FUNCTION quantize_binary(vector) RETURNS bit
 	AS 'MODULE_PATHNAME' LANGUAGE C IMMUTABLE STRICT PARALLEL SAFE;
 
+CREATE FUNCTION subvector(vector, int, int) RETURNS vector
+	AS 'MODULE_PATHNAME' LANGUAGE C IMMUTABLE STRICT PARALLEL SAFE;
+
 -- vector private functions
 
 CREATE FUNCTION vector_lt(vector, vector) RETURNS bool
diff --git a/src/vector.c b/src/vector.c
index fd21e44..e678f51 100644
--- a/src/vector.c
+++ b/src/vector.c
@@ -877,6 +877,49 @@ quantize_binary(PG_FUNCTION_ARGS)
 
 	PG_RETURN_VARBIT_P(result);
 }
+
+/*
+ * Get a subvector
+ *
+ * Returns the count elements of the vector that begin at 1-based position
+ * start. As with substring, the part of the requested range that falls
+ * outside the vector is dropped; CheckDim rejects a range with nothing left.
+ */
+PGDLLEXPORT PG_FUNCTION_INFO_V1(subvector);
+Datum
+subvector(PG_FUNCTION_ARGS)
+{
+	Vector	   *a = PG_GETARG_VECTOR_P(0);
+	int32		start = PG_GETARG_INT32(1);
+	int32		count = PG_GETARG_INT32(2);
+	int64		end;
+	float	   *ax = a->x;
+	Vector	   *result;
+	int			dim;
+
+	/*
+	 * Exclusive end of the requested range. Use 64-bit arithmetic since
+	 * start + count can overflow int32 for large arguments
+	 */
+	end = (int64) start + (int64) count;
+
+	/* Indexing starts at 1, like substring */
+	if (start < 1)
+		start = 1;
+
+	if (end > a->dim)
+		end = (int64) a->dim + 1;
+
+	/* An empty or inverted range becomes 0 so CheckDim reports it */
+	dim = end > start ? (int) (end - start) : 0;
+	CheckDim(dim);
+	result = InitVector(dim);
+
+	for (int i = 0; i < dim; i++)
+		result->x[i] = ax[start - 1 + i];
+
+	PG_RETURN_POINTER(result);
+}
 
 /*
  * Internal helper to compare vectors
diff --git a/test/expected/vector_functions.out b/test/expected/vector_functions.out
index c6b594c..8be1525 100644
--- a/test/expected/vector_functions.out
+++ b/test/expected/vector_functions.out
@@ -284,6 +284,50 @@ SELECT quantize_binary('[0,0.1,-0.2,-0.3,0.4,0.5,0.6,-0.7,0.8,-0.9,1]');
  01001110101
 (1 row)
 
+SELECT subvector('[1,2,3,4,5]', 1, 3);
+ subvector 
+-----------
+ [1,2,3]
+(1 row)
+
+SELECT subvector('[1,2,3,4,5]', 3, 2);
+ subvector 
+-----------
+ [3,4]
+(1 row)
+
+SELECT subvector('[1,2,3,4,5]', -1, 3);
+ subvector 
+-----------
+ [1]
+(1 row)
+
+SELECT subvector('[1,2,3,4,5]', 3, 9);
+ subvector 
+-----------
+ [3,4,5]
+(1 row)
+
+SELECT subvector('[1,2,3,4,5]', 1, 0);
+ERROR:  vector must have at least 1 dimension
+SELECT subvector('[1,2,3,4,5]', 3, -1);
+ERROR:  vector must have at least 1 dimension
+SELECT subvector('[1,2,3,4,5]', -1, 2);
+ERROR:  vector must have at least 1 dimension
+SELECT subvector('[1,2,3,4,5]', 2147483647, 10);
+ERROR:  vector must have at least 1 dimension
+SELECT subvector('[1,2,3,4,5]', 3, 2147483647);
+ subvector 
+-----------
+ [3,4,5]
+(1 row)
+
+SELECT subvector('[1,2,3,4,5]', -2147483644, 2147483647);
+ subvector 
+-----------
+ [1,2]
+(1 row)
+
 SELECT avg(v) FROM unnest(ARRAY['[1,2,3]'::vector, '[3,5,7]']) v;
     avg    
 -----------
diff --git a/test/sql/vector_functions.sql b/test/sql/vector_functions.sql
index a91aa48..8e8cc6e 100644
--- a/test/sql/vector_functions.sql
+++ b/test/sql/vector_functions.sql
@@ -61,6 +61,17 @@ SELECT l1_distance('[3e38]'::vector, '[-3e38]');
 SELECT quantize_binary('[1,0,-1]');
 SELECT quantize_binary('[0,0.1,-0.2,-0.3,0.4,0.5,0.6,-0.7,0.8,-0.9,1]');
 
+SELECT subvector('[1,2,3,4,5]', 1, 3);
+SELECT subvector('[1,2,3,4,5]', 3, 2);
+SELECT subvector('[1,2,3,4,5]', -1, 3);
+SELECT subvector('[1,2,3,4,5]', 3, 9);
+SELECT subvector('[1,2,3,4,5]', 1, 0);
+SELECT subvector('[1,2,3,4,5]', 3, -1);
+SELECT subvector('[1,2,3,4,5]', -1, 2);
+SELECT subvector('[1,2,3,4,5]', 2147483647, 10);
+SELECT subvector('[1,2,3,4,5]', 3, 2147483647);
+SELECT subvector('[1,2,3,4,5]', -2147483644, 2147483647);
+
 SELECT avg(v) FROM unnest(ARRAY['[1,2,3]'::vector, '[3,5,7]']) v;
 SELECT avg(v) FROM unnest(ARRAY['[1,2,3]'::vector, '[3,5,7]', NULL]) v;
 SELECT avg(v) FROM unnest(ARRAY[]::vector[]) v;