Fix integer overflow in subvector() function (#530)

`end = start + count` can overflow if `start` is very large, which
leads to a segfault later in the function. Add a test case for it.
This commit is contained in:
Heikki Linnakangas
2024-04-24 11:20:16 +03:00
committed by GitHub
parent ad3f811fa3
commit 14b351bc92
6 changed files with 45 additions and 9 deletions

View File

@@ -934,17 +934,32 @@ halfvec_subvector(PG_FUNCTION_ARGS)
HalfVector *a = PG_GETARG_HALFVEC_P(0);
int32 start = PG_GETARG_INT32(1);
int32 count = PG_GETARG_INT32(2);
int32 end = start + count;
int32 end;
half *ax = a->x;
HalfVector *result;
int dim;
int32 dim;
if (count < 1)
ereport(ERROR,
(errcode(ERRCODE_DATA_EXCEPTION),
errmsg("halfvec must have at least 1 dimension")));
/*
* Check if (start + count > a->dim), avoiding integer overflow. a->dim
* and count are both positive, so a->dim - count won't overflow.
*/
if (start > a->dim - count)
end = a->dim + 1;
else
end = start + count;
/* Indexing starts at 1, like substring */
if (start < 1)
start = 1;
if (end > a->dim)
end = a->dim + 1;
else if (start > a->dim)
ereport(ERROR,
(errcode(ERRCODE_DATA_EXCEPTION),
errmsg("halfvec must have at least 1 dimension")));
dim = end - start;
CheckDim(dim);

View File

@@ -980,17 +980,32 @@ subvector(PG_FUNCTION_ARGS)
Vector *a = PG_GETARG_VECTOR_P(0);
int32 start = PG_GETARG_INT32(1);
int32 count = PG_GETARG_INT32(2);
int32 end = start + count;
int32 end;
float *ax = a->x;
Vector *result;
int dim;
if (count < 1)
ereport(ERROR,
(errcode(ERRCODE_DATA_EXCEPTION),
errmsg("vector must have at least 1 dimension")));
/*
* Check if (start + count > a->dim), avoiding integer overflow. a->dim
* and count are both positive, so a->dim - count won't overflow.
*/
if (start > a->dim - count)
end = a->dim + 1;
else
end = start + count;
/* Indexing starts at 1, like substring */
if (start < 1)
start = 1;
if (end > a->dim)
end = a->dim + 1;
else if (start > a->dim)
ereport(ERROR,
(errcode(ERRCODE_DATA_EXCEPTION),
errmsg("vector must have at least 1 dimension")));
dim = end - start;
CheckDim(dim);

View File

@@ -406,6 +406,8 @@ SELECT subvector('[1,2,3,4,5]'::halfvec, 3, -1);
ERROR: halfvec must have at least 1 dimension
SELECT subvector('[1,2,3,4,5]'::halfvec, -1, 2);
ERROR: halfvec must have at least 1 dimension
SELECT subvector('[1,2,3,4,5]'::halfvec, 2147483647, 10);
ERROR: halfvec must have at least 1 dimension
SELECT avg(v) FROM unnest(ARRAY['[1,2,3]'::halfvec, '[3,5,7]']) v;
avg
-----------

View File

@@ -430,6 +430,8 @@ SELECT subvector('[1,2,3,4,5]'::vector, 3, -1);
ERROR: vector must have at least 1 dimension
SELECT subvector('[1,2,3,4,5]'::vector, -1, 2);
ERROR: vector must have at least 1 dimension
SELECT subvector('[1,2,3,4,5]'::vector, 2147483647, 10);
ERROR: vector must have at least 1 dimension
SELECT avg(v) FROM unnest(ARRAY['[1,2,3]'::vector, '[3,5,7]']) v;
avg
-----------

View File

@@ -90,6 +90,7 @@ SELECT subvector('[1,2,3,4,5]'::halfvec, 3, 9);
SELECT subvector('[1,2,3,4,5]'::halfvec, 1, 0);
SELECT subvector('[1,2,3,4,5]'::halfvec, 3, -1);
SELECT subvector('[1,2,3,4,5]'::halfvec, -1, 2);
SELECT subvector('[1,2,3,4,5]'::halfvec, 2147483647, 10);
SELECT avg(v) FROM unnest(ARRAY['[1,2,3]'::halfvec, '[3,5,7]']) v;
SELECT avg(v) FROM unnest(ARRAY['[1,2,3]'::halfvec, '[3,5,7]', NULL]) v;

View File

@@ -94,6 +94,7 @@ SELECT subvector('[1,2,3,4,5]'::vector, 3, 9);
SELECT subvector('[1,2,3,4,5]'::vector, 1, 0);
SELECT subvector('[1,2,3,4,5]'::vector, 3, -1);
SELECT subvector('[1,2,3,4,5]'::vector, -1, 2);
SELECT subvector('[1,2,3,4,5]'::vector, 2147483647, 10);
SELECT avg(v) FROM unnest(ARRAY['[1,2,3]'::vector, '[3,5,7]']) v;
SELECT avg(v) FROM unnest(ARRAY['[1,2,3]'::vector, '[3,5,7]', NULL]) v;