mirror of
https://github.com/pgvector/pgvector.git
synced 2026-06-06 05:51:21 +08:00
Prevent overflow at cost to speed
This commit is contained in:
12
src/vector.c
12
src/vector.c
@@ -754,7 +754,8 @@ normalize_l2(PG_FUNCTION_ARGS)
|
||||
{
|
||||
Vector *a = PG_GETARG_VECTOR_P(0);
|
||||
float *ax = a->x;
|
||||
float norm = 0.0;
|
||||
double norm = 0.0;
|
||||
float normf;
|
||||
Vector *result;
|
||||
float *rx;
|
||||
|
||||
@@ -763,15 +764,16 @@ normalize_l2(PG_FUNCTION_ARGS)
|
||||
|
||||
/* Auto-vectorized */
|
||||
for (int i = 0; i < a->dim; i++)
|
||||
norm += ax[i] * ax[i];
|
||||
norm += (double) ax[i] * (double) ax[i];
|
||||
|
||||
norm = sqrtf(norm);
|
||||
norm = sqrt(norm);
|
||||
normf = (float) norm;
|
||||
|
||||
if (norm > 0)
|
||||
if (normf > 0)
|
||||
{
|
||||
/* Auto-vectorized */
|
||||
for (int i = 0, imax = a->dim; i < imax; i++)
|
||||
rx[i] = ax[i] / norm;
|
||||
rx[i] = ax[i] / normf;
|
||||
|
||||
/* Check for overflow */
|
||||
for (int i = 0, imax = a->dim; i < imax; i++)
|
||||
|
||||
@@ -72,6 +72,12 @@ SELECT normalize_l2('[0,0]');
|
||||
[0,0]
|
||||
(1 row)
|
||||
|
||||
SELECT normalize_l2('[3e38]');
|
||||
normalize_l2
|
||||
--------------
|
||||
[1]
|
||||
(1 row)
|
||||
|
||||
SELECT l2_distance('[0,0]', '[3,4]');
|
||||
l2_distance
|
||||
-------------
|
||||
|
||||
@@ -16,6 +16,7 @@ SELECT normalize_l2('[3,4]');
|
||||
SELECT normalize_l2('[3,0]');
|
||||
SELECT normalize_l2('[0,0.1]');
|
||||
SELECT normalize_l2('[0,0]');
|
||||
SELECT normalize_l2('[3e38]');
|
||||
|
||||
SELECT l2_distance('[0,0]', '[3,4]');
|
||||
SELECT l2_distance('[0,0]', '[0,1]');
|
||||
|
||||
Reference in New Issue
Block a user