Added comparison operators for sparsevec

This commit is contained in:
Andrew Kane
2024-04-14 13:40:37 -07:00
parent c68c2867fd
commit 88788472ba
5 changed files with 383 additions and 2 deletions

View File

@@ -798,3 +798,135 @@ sparsevec_norm(PG_FUNCTION_ARGS)
PG_RETURN_FLOAT8(sqrt(norm));
}
/*
* Internal helper to compare sparse vectors
*/
static int
sparsevec_cmp_internal(SparseVector * a, SparseVector * b)
{
float *ax = SPARSEVEC_VALUES(a);
float *bx = SPARSEVEC_VALUES(b);
int nnz = Min(a->nnz, b->nnz);
/* Check values before dimensions to be consistent with Postgres arrays */
for (int i = 0; i < nnz; i++)
{
if (a->indices[i] < b->indices[i])
return ax[i] < 0 ? -1 : 1;
if (a->indices[i] > b->indices[i])
return bx[i] < 0 ? 1 : -1;
if (ax[i] < bx[i])
return -1;
if (ax[i] > bx[i])
return 1;
}
if (a->nnz < b->nnz)
return bx[nnz + 1] < 0 ? 1 : -1;
if (a->nnz > b->nnz)
return ax[nnz + 1] < 0 ? -1 : 1;
if (a->dim < b->dim)
return -1;
if (a->dim > b->dim)
return 1;
return 0;
}
/*
* Less than
*/
PGDLLEXPORT PG_FUNCTION_INFO_V1(sparsevec_lt);
Datum
sparsevec_lt(PG_FUNCTION_ARGS)
{
SparseVector *a = PG_GETARG_SPARSEVEC_P(0);
SparseVector *b = PG_GETARG_SPARSEVEC_P(1);
PG_RETURN_BOOL(sparsevec_cmp_internal(a, b) < 0);
}
/*
* Less than or equal
*/
PGDLLEXPORT PG_FUNCTION_INFO_V1(sparsevec_le);
Datum
sparsevec_le(PG_FUNCTION_ARGS)
{
SparseVector *a = PG_GETARG_SPARSEVEC_P(0);
SparseVector *b = PG_GETARG_SPARSEVEC_P(1);
PG_RETURN_BOOL(sparsevec_cmp_internal(a, b) <= 0);
}
/*
* Equal
*/
PGDLLEXPORT PG_FUNCTION_INFO_V1(sparsevec_eq);
Datum
sparsevec_eq(PG_FUNCTION_ARGS)
{
SparseVector *a = PG_GETARG_SPARSEVEC_P(0);
SparseVector *b = PG_GETARG_SPARSEVEC_P(1);
PG_RETURN_BOOL(sparsevec_cmp_internal(a, b) == 0);
}
/*
* Not equal
*/
PGDLLEXPORT PG_FUNCTION_INFO_V1(sparsevec_ne);
Datum
sparsevec_ne(PG_FUNCTION_ARGS)
{
SparseVector *a = PG_GETARG_SPARSEVEC_P(0);
SparseVector *b = PG_GETARG_SPARSEVEC_P(1);
PG_RETURN_BOOL(sparsevec_cmp_internal(a, b) != 0);
}
/*
* Greater than or equal
*/
PGDLLEXPORT PG_FUNCTION_INFO_V1(sparsevec_ge);
Datum
sparsevec_ge(PG_FUNCTION_ARGS)
{
SparseVector *a = PG_GETARG_SPARSEVEC_P(0);
SparseVector *b = PG_GETARG_SPARSEVEC_P(1);
PG_RETURN_BOOL(sparsevec_cmp_internal(a, b) >= 0);
}
/*
* Greater than
*/
PGDLLEXPORT PG_FUNCTION_INFO_V1(sparsevec_gt);
Datum
sparsevec_gt(PG_FUNCTION_ARGS)
{
SparseVector *a = PG_GETARG_SPARSEVEC_P(0);
SparseVector *b = PG_GETARG_SPARSEVEC_P(1);
PG_RETURN_BOOL(sparsevec_cmp_internal(a, b) > 0);
}
/*
* Compare sparse vectors
*/
PGDLLEXPORT PG_FUNCTION_INFO_V1(sparsevec_cmp);
Datum
sparsevec_cmp(PG_FUNCTION_ARGS)
{
SparseVector *a = PG_GETARG_SPARSEVEC_P(0);
SparseVector *b = PG_GETARG_SPARSEVEC_P(1);
PG_RETURN_INT32(sparsevec_cmp_internal(a, b));
}