diff --git a/CHANGELOG.md b/CHANGELOG.md index 4acdfb6..82e589b 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -5,6 +5,7 @@ - Added element-wise multiplication for vectors - Added `sum` aggregate - Improved performance of distance functions +- Fixed out-of-range results for cosine distance ## 0.4.4 (2023-06-12) diff --git a/src/vector.c b/src/vector.c index 78dc4e4..6709b0d 100644 --- a/src/vector.c +++ b/src/vector.c @@ -630,6 +630,7 @@ cosine_distance(PG_FUNCTION_ARGS) float distance = 0.0; float norma = 0.0; float normb = 0.0; + double similarity; CheckDims(a, b); @@ -642,7 +643,15 @@ cosine_distance(PG_FUNCTION_ARGS) } /* Use sqrt(a * b) over sqrt(a) * sqrt(b) */ - PG_RETURN_FLOAT8(1.0 - ((double) distance / sqrt((double) norma * (double) normb))); + similarity = (double) distance / sqrt((double) norma * (double) normb); + + /* Keep in range */ + if (similarity > 1) + similarity = 1.0; + else if (similarity < -1) + similarity = -1.0; + + PG_RETURN_FLOAT8(1.0 - similarity); } /* diff --git a/test/expected/functions.out b/test/expected/functions.out index 46bb8fe..268706a 100644 --- a/test/expected/functions.out +++ b/test/expected/functions.out @@ -96,6 +96,18 @@ SELECT cosine_distance('[1,1]', '[-1,-1]'); SELECT cosine_distance('[1,2]', '[3]'); ERROR: different vector dimensions 2 and 1 +SELECT cosine_distance(array_fill(0.1, ARRAY[1536])::vector, array_fill(0.111, ARRAY[1536])::vector); + cosine_distance +----------------- + 0 +(1 row) + +SELECT cosine_distance(array_fill(0.1, ARRAY[1536])::vector, array_fill(-0.111, ARRAY[1536])::vector); + cosine_distance +----------------- + 2 +(1 row) + SELECT l1_distance('[0,0]', '[3,4]'); l1_distance ------------- diff --git a/test/sql/functions.sql b/test/sql/functions.sql index fa29717..26bf567 100644 --- a/test/sql/functions.sql +++ b/test/sql/functions.sql @@ -24,6 +24,8 @@ SELECT cosine_distance('[1,2]', '[0,0]'); SELECT cosine_distance('[1,1]', '[1,1]'); SELECT cosine_distance('[1,1]', '[-1,-1]'); SELECT
cosine_distance('[1,2]', '[3]'); +SELECT cosine_distance(array_fill(0.1, ARRAY[1536])::vector, array_fill(0.111, ARRAY[1536])::vector); +SELECT cosine_distance(array_fill(0.1, ARRAY[1536])::vector, array_fill(-0.111, ARRAY[1536])::vector); SELECT l1_distance('[0,0]', '[3,4]'); SELECT l1_distance('[0,0]', '[0,1]');