Added subvector function

This commit is contained in:
Andrew Kane
2024-03-25 21:51:03 -07:00
parent 31e41b3ba9
commit d49d053e84
6 changed files with 78 additions and 0 deletions

View File

@@ -1,3 +1,7 @@
## 0.7.0 (unreleased)
- Added `subvector` function
## 0.6.2 (2024-03-18)
- Reduced lock contention with parallel HNSW index builds

View File

@@ -0,0 +1,5 @@
-- complain if script is sourced in psql, rather than via CREATE EXTENSION
\echo Use "ALTER EXTENSION vector UPDATE TO '0.7.0'" to load this file. \quit
CREATE FUNCTION subvector(vector, int, int) RETURNS vector
AS 'MODULE_PATHNAME' LANGUAGE C IMMUTABLE STRICT PARALLEL SAFE;

View File

@@ -58,6 +58,9 @@ CREATE FUNCTION vector_sub(vector, vector) RETURNS vector
CREATE FUNCTION vector_mul(vector, vector) RETURNS vector
AS 'MODULE_PATHNAME' LANGUAGE C IMMUTABLE STRICT PARALLEL SAFE;
CREATE FUNCTION subvector(vector, int, int) RETURNS vector
AS 'MODULE_PATHNAME' LANGUAGE C IMMUTABLE STRICT PARALLEL SAFE;
-- private functions
CREATE FUNCTION vector_lt(vector, vector) RETURNS bool

View File

@@ -860,6 +860,37 @@ vector_mul(PG_FUNCTION_ARGS)
PG_RETURN_POINTER(result);
}
/*
* Get a subset of a vector
*/
PGDLLEXPORT PG_FUNCTION_INFO_V1(subvector);
Datum
subvector(PG_FUNCTION_ARGS)
{
Vector *a = PG_GETARG_VECTOR_P(0);
int32 start = PG_GETARG_INT32(1);
int32 count = PG_GETARG_INT32(2);
int32 end = start + count;
float *ax = a->x;
Vector *result;
int dim;
if (start < 1)
start = 1;
if (end > a->dim)
end = a->dim + 1;
dim = end - start;
CheckDim(dim);
result = InitVector(dim);
for (int i = 0; i < dim; i++)
result->x[i] = ax[start - 1 + i];
PG_RETURN_POINTER(result);
}
/*
* Internal helper to compare vectors
*/

View File

@@ -208,6 +208,34 @@ SELECT l1_distance('[3e38]', '[-3e38]');
Infinity
(1 row)
SELECT subvector('[1,2,3,4,5]', 1, 3);
subvector
-----------
[1,2,3]
(1 row)
SELECT subvector('[1,2,3,4,5]', 3, 2);
subvector
-----------
[3,4]
(1 row)
SELECT subvector('[1,2,3,4,5]', -1, 3);
subvector
-----------
[1]
(1 row)
SELECT subvector('[1,2,3,4,5]', 3, 9);
subvector
-----------
[3,4,5]
(1 row)
SELECT subvector('[1,2,3,4,5]', 1, 0);
ERROR: vector must have at least 1 dimension
SELECT subvector('[1,2,3,4,5]', -1, 2);
ERROR: vector must have at least 1 dimension
SELECT avg(v) FROM unnest(ARRAY['[1,2,3]'::vector, '[3,5,7]']) v;
avg
-----------

View File

@@ -48,6 +48,13 @@ SELECT l1_distance('[0,0]', '[0,1]');
SELECT l1_distance('[1,2]', '[3]');
SELECT l1_distance('[3e38]', '[-3e38]');
SELECT subvector('[1,2,3,4,5]', 1, 3);
SELECT subvector('[1,2,3,4,5]', 3, 2);
SELECT subvector('[1,2,3,4,5]', -1, 3);
SELECT subvector('[1,2,3,4,5]', 3, 9);
SELECT subvector('[1,2,3,4,5]', 1, 0);
SELECT subvector('[1,2,3,4,5]', -1, 2);
SELECT avg(v) FROM unnest(ARRAY['[1,2,3]'::vector, '[3,5,7]']) v;
SELECT avg(v) FROM unnest(ARRAY['[1,2,3]'::vector, '[3,5,7]', NULL]) v;
SELECT avg(v) FROM unnest(ARRAY[]::vector[]) v;