From fd4fbd238c4ced0c621b8e01a8b071587fea3451 Mon Sep 17 00:00:00 2001 From: Andrew Kane Date: Fri, 19 Apr 2024 16:54:19 -0700 Subject: [PATCH] Updated sparsevec input to support indices in any order [skip ci] --- src/sparsevec.c | 37 ++++++++++++++++++++++++------- test/expected/sparsevec_input.out | 14 ++++++++++-- test/sql/sparsevec_input.sql | 2 ++ 3 files changed, 43 insertions(+), 10 deletions(-) diff --git a/src/sparsevec.c b/src/sparsevec.c index f1726be..22f7602 100644 --- a/src/sparsevec.c +++ b/src/sparsevec.c @@ -18,6 +18,12 @@ #include "utils/builtins.h" #endif +typedef struct SparseInputElement +{ + int32 index; + float value; +} SparseInputElement; + /* * Ensure same dimensions */ @@ -164,6 +170,21 @@ sparsevec_isspace(char ch) return false; } +/* + * Compare indices + */ +static int +CompareIndices(const void *a, const void *b) +{ + if (((SparseInputElement *) a)->index < ((SparseInputElement *) b)->index) + return -1; + + if (((SparseInputElement *) a)->index > ((SparseInputElement *) b)->index) + return 1; + + return 0; +} + /* * Convert textual representation to internal representation */ @@ -178,8 +199,7 @@ sparsevec_in(PG_FUNCTION_ARGS) char *stringEnd; SparseVector *result; float *rvalues; - int32 *indices; - float *values; + SparseInputElement *elements; int maxNnz; int nnz = 0; @@ -197,8 +217,7 @@ sparsevec_in(PG_FUNCTION_ARGS) (errcode(ERRCODE_PROGRAM_LIMIT_EXCEEDED), errmsg("sparsevec cannot have more than %d non-zero elements", SPARSEVEC_MAX_NNZ))); - indices = palloc(maxNnz * sizeof(int32)); - values = palloc(maxNnz * sizeof(float)); + elements = palloc(maxNnz * sizeof(SparseInputElement)); pt = lit; @@ -291,8 +310,8 @@ sparsevec_in(PG_FUNCTION_ARGS) /* Do not store zero values */ if (value != 0) { - indices[nnz] = index; - values[nnz] = value; + elements[nnz].index = index; + elements[nnz].value = value; nnz++; } @@ -353,12 +372,14 @@ sparsevec_in(PG_FUNCTION_ARGS) CheckDim(dim); CheckExpectedDim(typmod, dim); + qsort(elements, nnz, sizeof(SparseInputElement), CompareIndices); + result = InitSparseVector(dim, nnz); rvalues = SPARSEVEC_VALUES(result); for (int i = 0; i < nnz; i++) { - result->indices[i] = indices[i]; - rvalues[i] = values[i]; + result->indices[i] = elements[i].index; + rvalues[i] = elements[i].value; CheckIndex(result->indices, i, dim); } diff --git a/test/expected/sparsevec_input.out b/test/expected/sparsevec_input.out index 2e8db67..a654aa6 100644 --- a/test/expected/sparsevec_input.out +++ b/test/expected/sparsevec_input.out @@ -164,8 +164,18 @@ SELECT '{1:0,2:1,3:0}/3'::sparsevec; (1 row) SELECT '{2:1,1:1}/2'::sparsevec; -ERROR: indexes must be in ascending order -LINE 1: SELECT '{2:1,1:1}/2'::sparsevec; + sparsevec +------------- + {1:1,2:1}/2 +(1 row) + +SELECT '{1:1,1:1}/2'::sparsevec; +ERROR: indexes must not contain duplicates +LINE 1: SELECT '{1:1,1:1}/2'::sparsevec; + ^ +SELECT '{1:1,2:1,1:1}/2'::sparsevec; +ERROR: indexes must not contain duplicates +LINE 1: SELECT '{1:1,2:1,1:1}/2'::sparsevec; ^ SELECT '{}/5'::sparsevec; sparsevec diff --git a/test/sql/sparsevec_input.sql b/test/sql/sparsevec_input.sql index 0c7dd6c..4b665bf 100644 --- a/test/sql/sparsevec_input.sql +++ b/test/sql/sparsevec_input.sql @@ -34,6 +34,8 @@ SELECT '{1:1a}/1'::sparsevec; SELECT '{1:1,}/1'::sparsevec; SELECT '{1:0,2:1,3:0}/3'::sparsevec; SELECT '{2:1,1:1}/2'::sparsevec; +SELECT '{1:1,1:1}/2'::sparsevec; +SELECT '{1:1,2:1,1:1}/2'::sparsevec; SELECT '{}/5'::sparsevec; SELECT '{}/-1'::sparsevec; SELECT '{}/100001'::sparsevec;