Added subvector function for sparsevec

This commit is contained in:
Andrew Kane
2024-04-03 15:10:06 -07:00
parent aaa2d644ce
commit 7ca7a64dbb
5 changed files with 97 additions and 0 deletions

View File

@@ -203,6 +203,9 @@ CREATE FUNCTION cosine_distance(sparsevec, sparsevec) RETURNS float8
CREATE FUNCTION sparsevec_norm(sparsevec) RETURNS float8
AS 'MODULE_PATHNAME' LANGUAGE C IMMUTABLE STRICT PARALLEL SAFE;
CREATE FUNCTION subvector(sparsevec, int, int) RETURNS sparsevec
AS 'MODULE_PATHNAME', 'sparsevec_subvector' LANGUAGE C IMMUTABLE STRICT PARALLEL SAFE;
CREATE FUNCTION sparsevec_l2_squared_distance(sparsevec, sparsevec) RETURNS float8
AS 'MODULE_PATHNAME' LANGUAGE C IMMUTABLE STRICT PARALLEL SAFE;

View File

@@ -512,6 +512,9 @@ CREATE FUNCTION cosine_distance(sparsevec, sparsevec) RETURNS float8
CREATE FUNCTION sparsevec_norm(sparsevec) RETURNS float8
AS 'MODULE_PATHNAME' LANGUAGE C IMMUTABLE STRICT PARALLEL SAFE;
CREATE FUNCTION subvector(sparsevec, int, int) RETURNS sparsevec
AS 'MODULE_PATHNAME', 'sparsevec_subvector' LANGUAGE C IMMUTABLE STRICT PARALLEL SAFE;
-- sparsevec private functions
CREATE FUNCTION sparsevec_l2_squared_distance(sparsevec, sparsevec) RETURNS float8

View File

@@ -776,3 +776,56 @@ sparsevec_norm(PG_FUNCTION_ARGS)
PG_RETURN_FLOAT8(sqrt(norm));
}
/*
* Get a subvector
*/
PGDLLEXPORT PG_FUNCTION_INFO_V1(sparsevec_subvector);
Datum
sparsevec_subvector(PG_FUNCTION_ARGS)
{
SparseVector *a = PG_GETARG_SPARSEVEC_P(0);
int32 start = PG_GETARG_INT32(1);
int32 count = PG_GETARG_INT32(2);
int32 end = start + count;
float *ax = SPARSEVEC_VALUES(a);
SparseVector *result;
float *rx;
int dim;
int nnz = 0;
int startIndex;
/* Indexing starts at 1, like substring */
if (start < 1)
start = 1;
if (end > a->dim)
end = a->dim + 1;
dim = end - start;
CheckDim(dim);
startIndex = dim;
for (startIndex = 0; startIndex < a->nnz; startIndex++)
{
if (a->indices[startIndex] >= start - 1)
break;
}
for (int i = startIndex; i < a->nnz; i++)
{
if (a->indices[i] < end - 1)
nnz++;
}
result = InitSparseVector(dim, nnz);
rx = SPARSEVEC_VALUES(result);
for (int i = 0; i < nnz; i++)
{
result->indices[i] = a->indices[startIndex + i];
rx[i] = ax[startIndex + i];
}
PG_RETURN_POINTER(result);
}

View File

@@ -60,3 +60,33 @@ SELECT cosine_distance('{}/1'::sparsevec, '{}/1');
SELECT cosine_distance('{0:1}/2'::sparsevec, '{0:1}/3');
ERROR: different sparsevec dimensions 2 and 3
SELECT subvector('{0:1,1:2,2:3,3:4,4:5}/5'::sparsevec, 1, 3);
subvector
-----------------
{0:1,1:2,2:3}/3
(1 row)
SELECT subvector('{0:1,1:2,2:3,3:4,4:5}/5'::sparsevec, 3, 2);
subvector
-------------
{2:3,3:4}/2
(1 row)
SELECT subvector('{0:1,1:2,2:3,3:4,4:5}/5'::sparsevec, -1, 3);
subvector
-----------
{0:1}/1
(1 row)
SELECT subvector('{0:1,1:2,2:3,3:4,4:5}/5'::sparsevec, 3, 9);
subvector
-----------------
{2:3,3:4,4:5}/3
(1 row)
SELECT subvector('{0:1,1:2,2:3,3:4,4:5}/5'::sparsevec, 1, 0);
ERROR: sparsevec must have at least 1 dimension
SELECT subvector('{0:1,1:2,2:3,3:4,4:5}/5'::sparsevec, 3, -1);
ERROR: sparsevec must have at least 1 dimension
SELECT subvector('{0:1,1:2,2:3,3:4,4:5}/5'::sparsevec, -1, 2);
ERROR: sparsevec must have at least 1 dimension

View File

@@ -11,3 +11,11 @@ SELECT cosine_distance('{0:1,1:1}/2'::sparsevec, '{0:-1,1:-1}/2');
SELECT cosine_distance('{0:1}/2'::sparsevec, '{1:2}/2');
SELECT cosine_distance('{}/1'::sparsevec, '{}/1');
SELECT cosine_distance('{0:1}/2'::sparsevec, '{0:1}/3');
SELECT subvector('{0:1,1:2,2:3,3:4,4:5}/5'::sparsevec, 1, 3);
SELECT subvector('{0:1,1:2,2:3,3:4,4:5}/5'::sparsevec, 3, 2);
SELECT subvector('{0:1,1:2,2:3,3:4,4:5}/5'::sparsevec, -1, 3);
SELECT subvector('{0:1,1:2,2:3,3:4,4:5}/5'::sparsevec, 3, 9);
SELECT subvector('{0:1,1:2,2:3,3:4,4:5}/5'::sparsevec, 1, 0);
SELECT subvector('{0:1,1:2,2:3,3:4,4:5}/5'::sparsevec, 3, -1);
SELECT subvector('{0:1,1:2,2:3,3:4,4:5}/5'::sparsevec, -1, 2);