From 14b351bc929b164e6e0b74aa5c4e937b458f3904 Mon Sep 17 00:00:00 2001 From: Heikki Linnakangas Date: Wed, 24 Apr 2024 11:20:16 +0300 Subject: [PATCH] Fix integer overflow in subvector() function (#530) `end = start + count` can overflow if `start` is very large. That leads to a segfault later in the function. Add test case for it. --- src/halfvec.c | 25 ++++++++++++++++++++----- src/vector.c | 23 +++++++++++++++++++---- test/expected/halfvec_functions.out | 2 ++ test/expected/vector_functions.out | 2 ++ test/sql/halfvec_functions.sql | 1 + test/sql/vector_functions.sql | 1 + 6 files changed, 45 insertions(+), 9 deletions(-) diff --git a/src/halfvec.c b/src/halfvec.c index 508ddc3..81468c3 100644 --- a/src/halfvec.c +++ b/src/halfvec.c @@ -934,17 +934,32 @@ halfvec_subvector(PG_FUNCTION_ARGS) HalfVector *a = PG_GETARG_HALFVEC_P(0); int32 start = PG_GETARG_INT32(1); int32 count = PG_GETARG_INT32(2); - int32 end = start + count; + int32 end; half *ax = a->x; HalfVector *result; - int dim; + int32 dim; + + if (count < 1) + ereport(ERROR, + (errcode(ERRCODE_DATA_EXCEPTION), + errmsg("halfvec must have at least 1 dimension"))); + + /* + * Check if (start + count > a->dim), avoiding integer overflow. a->dim + * and count are both positive, so a->dim - count won't overflow. + */ + if (start > a->dim - count) + end = a->dim + 1; + else + end = start + count; /* Indexing starts at 1, like substring */ if (start < 1) start = 1; - - if (end > a->dim) - end = a->dim + 1; + else if (start > a->dim) + ereport(ERROR, + (errcode(ERRCODE_DATA_EXCEPTION), + errmsg("halfvec must have at least 1 dimension"))); dim = end - start; CheckDim(dim); diff --git a/src/vector.c b/src/vector.c index 63fb529..098df23 100644 --- a/src/vector.c +++ b/src/vector.c @@ -980,17 +980,32 @@ subvector(PG_FUNCTION_ARGS) Vector *a = PG_GETARG_VECTOR_P(0); int32 start = PG_GETARG_INT32(1); int32 count = PG_GETARG_INT32(2); - int32 end = start + count; + int32 end; float *ax = a->x; Vector *result; int dim; + if (count < 1) + ereport(ERROR, + (errcode(ERRCODE_DATA_EXCEPTION), + errmsg("vector must have at least 1 dimension"))); + + /* + * Check if (start + count > a->dim), avoiding integer overflow. a->dim + * and count are both positive, so a->dim - count won't overflow. + */ + if (start > a->dim - count) + end = a->dim + 1; + else + end = start + count; + /* Indexing starts at 1, like substring */ if (start < 1) start = 1; - - if (end > a->dim) - end = a->dim + 1; + else if (start > a->dim) + ereport(ERROR, + (errcode(ERRCODE_DATA_EXCEPTION), + errmsg("vector must have at least 1 dimension"))); dim = end - start; CheckDim(dim); diff --git a/test/expected/halfvec_functions.out b/test/expected/halfvec_functions.out index 652430d..5bb4be0 100644 --- a/test/expected/halfvec_functions.out +++ b/test/expected/halfvec_functions.out @@ -406,6 +406,8 @@ SELECT subvector('[1,2,3,4,5]'::halfvec, 3, -1); ERROR: halfvec must have at least 1 dimension SELECT subvector('[1,2,3,4,5]'::halfvec, -1, 2); ERROR: halfvec must have at least 1 dimension +SELECT subvector('[1,2,3,4,5]'::halfvec, 2147483647, 10); +ERROR: halfvec must have at least 1 dimension SELECT avg(v) FROM unnest(ARRAY['[1,2,3]'::halfvec, '[3,5,7]']) v; avg ----------- diff --git a/test/expected/vector_functions.out b/test/expected/vector_functions.out index 15b18eb..3d3e14b 100644 --- a/test/expected/vector_functions.out +++ b/test/expected/vector_functions.out @@ -430,6 +430,8 @@ SELECT subvector('[1,2,3,4,5]'::vector, 3, -1); ERROR: vector must have at least 1 dimension SELECT subvector('[1,2,3,4,5]'::vector, -1, 2); ERROR: vector must have at least 1 dimension +SELECT subvector('[1,2,3,4,5]'::vector, 2147483647, 10); +ERROR: vector must have at least 1 dimension SELECT avg(v) FROM unnest(ARRAY['[1,2,3]'::vector, '[3,5,7]']) v; avg ----------- diff --git a/test/sql/halfvec_functions.sql b/test/sql/halfvec_functions.sql index c883c57..3810b8e 100644 --- a/test/sql/halfvec_functions.sql +++ b/test/sql/halfvec_functions.sql @@ -90,6 +90,7 @@ SELECT subvector('[1,2,3,4,5]'::halfvec, 3, 9); SELECT subvector('[1,2,3,4,5]'::halfvec, 1, 0); SELECT subvector('[1,2,3,4,5]'::halfvec, 3, -1); SELECT subvector('[1,2,3,4,5]'::halfvec, -1, 2); +SELECT subvector('[1,2,3,4,5]'::halfvec, 2147483647, 10); SELECT avg(v) FROM unnest(ARRAY['[1,2,3]'::halfvec, '[3,5,7]']) v; SELECT avg(v) FROM unnest(ARRAY['[1,2,3]'::halfvec, '[3,5,7]', NULL]) v; diff --git a/test/sql/vector_functions.sql b/test/sql/vector_functions.sql index d2dae5a..fb95a74 100644 --- a/test/sql/vector_functions.sql +++ b/test/sql/vector_functions.sql @@ -94,6 +94,7 @@ SELECT subvector('[1,2,3,4,5]'::vector, 3, 9); SELECT subvector('[1,2,3,4,5]'::vector, 1, 0); SELECT subvector('[1,2,3,4,5]'::vector, 3, -1); SELECT subvector('[1,2,3,4,5]'::vector, -1, 2); +SELECT subvector('[1,2,3,4,5]'::vector, 2147483647, 10); SELECT avg(v) FROM unnest(ARRAY['[1,2,3]'::vector, '[3,5,7]']) v; SELECT avg(v) FROM unnest(ARRAY['[1,2,3]'::vector, '[3,5,7]', NULL]) v;