diff --git a/sql/vector--0.6.2--0.7.0.sql b/sql/vector--0.6.2--0.7.0.sql index 3fdc8f1..3d45d49 100644 --- a/sql/vector--0.6.2--0.7.0.sql +++ b/sql/vector--0.6.2--0.7.0.sql @@ -203,6 +203,9 @@ CREATE FUNCTION cosine_distance(sparsevec, sparsevec) RETURNS float8 CREATE FUNCTION sparsevec_norm(sparsevec) RETURNS float8 AS 'MODULE_PATHNAME' LANGUAGE C IMMUTABLE STRICT PARALLEL SAFE; +CREATE FUNCTION subvector(sparsevec, int, int) RETURNS sparsevec + AS 'MODULE_PATHNAME', 'sparsevec_subvector' LANGUAGE C IMMUTABLE STRICT PARALLEL SAFE; + CREATE FUNCTION sparsevec_l2_squared_distance(sparsevec, sparsevec) RETURNS float8 AS 'MODULE_PATHNAME' LANGUAGE C IMMUTABLE STRICT PARALLEL SAFE; diff --git a/sql/vector.sql b/sql/vector.sql index c694b1c..06c831d 100644 --- a/sql/vector.sql +++ b/sql/vector.sql @@ -512,6 +512,9 @@ CREATE FUNCTION cosine_distance(sparsevec, sparsevec) RETURNS float8 CREATE FUNCTION sparsevec_norm(sparsevec) RETURNS float8 AS 'MODULE_PATHNAME' LANGUAGE C IMMUTABLE STRICT PARALLEL SAFE; +CREATE FUNCTION subvector(sparsevec, int, int) RETURNS sparsevec + AS 'MODULE_PATHNAME', 'sparsevec_subvector' LANGUAGE C IMMUTABLE STRICT PARALLEL SAFE; + -- sparsevec private functions CREATE FUNCTION sparsevec_l2_squared_distance(sparsevec, sparsevec) RETURNS float8 diff --git a/src/sparsevec.c b/src/sparsevec.c index 22649cd..a288297 100644 --- a/src/sparsevec.c +++ b/src/sparsevec.c @@ -776,3 +776,56 @@ sparsevec_norm(PG_FUNCTION_ARGS) PG_RETURN_FLOAT8(sqrt(norm)); } + +/* + * Get a subvector + */ +PGDLLEXPORT PG_FUNCTION_INFO_V1(sparsevec_subvector); +Datum +sparsevec_subvector(PG_FUNCTION_ARGS) +{ + SparseVector *a = PG_GETARG_SPARSEVEC_P(0); + int32 start = PG_GETARG_INT32(1); + int32 count = PG_GETARG_INT32(2); + int32 end = start + count; + float *ax = SPARSEVEC_VALUES(a); + SparseVector *result; + float *rx; + int dim; + int nnz = 0; + int startIndex; + + /* Indexing starts at 1, like substring */ + if (start < 1) + start = 1; + + if (end > a->dim) + end = a->dim + 1; + + dim = end - start; + CheckDim(dim); + startIndex = dim; + + for (startIndex = 0; startIndex < a->nnz; startIndex++) + { + if (a->indices[startIndex] >= start - 1) + break; + } + + for (int i = startIndex; i < a->nnz; i++) + { + if (a->indices[i] < end - 1) + nnz++; + } + + result = InitSparseVector(dim, nnz); + rx = SPARSEVEC_VALUES(result); + + for (int i = 0; i < nnz; i++) + { + result->indices[i] = a->indices[startIndex + i]; + rx[i] = ax[startIndex + i]; + } + + PG_RETURN_POINTER(result); +} diff --git a/test/expected/sparsevec_functions.out b/test/expected/sparsevec_functions.out index 07117d8..2f2f4d0 100644 --- a/test/expected/sparsevec_functions.out +++ b/test/expected/sparsevec_functions.out @@ -60,3 +60,33 @@ SELECT cosine_distance('{}/1'::sparsevec, '{}/1'); SELECT cosine_distance('{0:1}/2'::sparsevec, '{0:1}/3'); ERROR: different sparsevec dimensions 2 and 3 +SELECT subvector('{0:1,1:2,2:3,3:4,4:5}/5'::sparsevec, 1, 3); + subvector +----------------- + {0:1,1:2,2:3}/3 +(1 row) + +SELECT subvector('{0:1,1:2,2:3,3:4,4:5}/5'::sparsevec, 3, 2); + subvector +------------- + {2:3,3:4}/2 +(1 row) + +SELECT subvector('{0:1,1:2,2:3,3:4,4:5}/5'::sparsevec, -1, 3); + subvector +----------- + {0:1}/1 +(1 row) + +SELECT subvector('{0:1,1:2,2:3,3:4,4:5}/5'::sparsevec, 3, 9); + subvector +----------------- + {2:3,3:4,4:5}/3 +(1 row) + +SELECT subvector('{0:1,1:2,2:3,3:4,4:5}/5'::sparsevec, 1, 0); +ERROR: sparsevec must have at least 1 dimension +SELECT subvector('{0:1,1:2,2:3,3:4,4:5}/5'::sparsevec, 3, -1); +ERROR: sparsevec must have at least 1 dimension +SELECT subvector('{0:1,1:2,2:3,3:4,4:5}/5'::sparsevec, -1, 2); +ERROR: sparsevec must have at least 1 dimension diff --git a/test/sql/sparsevec_functions.sql b/test/sql/sparsevec_functions.sql index 86f7990..cf9bef8 100644 --- a/test/sql/sparsevec_functions.sql +++ b/test/sql/sparsevec_functions.sql @@ -11,3 +11,11 @@ SELECT cosine_distance('{0:1,1:1}/2'::sparsevec, '{0:-1,1:-1}/2'); SELECT cosine_distance('{0:1}/2'::sparsevec, '{1:2}/2'); SELECT cosine_distance('{}/1'::sparsevec, '{}/1'); SELECT cosine_distance('{0:1}/2'::sparsevec, '{0:1}/3'); + +SELECT subvector('{0:1,1:2,2:3,3:4,4:5}/5'::sparsevec, 1, 3); +SELECT subvector('{0:1,1:2,2:3,3:4,4:5}/5'::sparsevec, 3, 2); +SELECT subvector('{0:1,1:2,2:3,3:4,4:5}/5'::sparsevec, -1, 3); +SELECT subvector('{0:1,1:2,2:3,3:4,4:5}/5'::sparsevec, 3, 9); +SELECT subvector('{0:1,1:2,2:3,3:4,4:5}/5'::sparsevec, 1, 0); +SELECT subvector('{0:1,1:2,2:3,3:4,4:5}/5'::sparsevec, 3, -1); +SELECT subvector('{0:1,1:2,2:3,3:4,4:5}/5'::sparsevec, -1, 2);