diff --git a/CHANGELOG.md b/CHANGELOG.md index bcd4dfd..b8d4dc8 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -1,3 +1,7 @@ +## 0.5.1 (unreleased) + +- Added `angular_distance` function + ## 0.5.0 (2023-08-28) - Added HNSW index type diff --git a/sql/vector--0.5.0--0.5.1.sql b/sql/vector--0.5.0--0.5.1.sql new file mode 100644 index 0000000..a9d72fe --- /dev/null +++ b/sql/vector--0.5.0--0.5.1.sql @@ -0,0 +1,5 @@ +-- complain if script is sourced in psql, rather than via CREATE EXTENSION +\echo Use "ALTER EXTENSION vector UPDATE TO '0.5.1'" to load this file. \quit + +CREATE FUNCTION angular_distance(vector, vector) RETURNS float8 + AS 'MODULE_PATHNAME' LANGUAGE C IMMUTABLE STRICT PARALLEL SAFE; diff --git a/sql/vector.sql b/sql/vector.sql index 137931f..69ec142 100644 --- a/sql/vector.sql +++ b/sql/vector.sql @@ -87,6 +87,9 @@ CREATE FUNCTION vector_l2_squared_distance(vector, vector) RETURNS float8 CREATE FUNCTION vector_negative_inner_product(vector, vector) RETURNS float8 AS 'MODULE_PATHNAME' LANGUAGE C IMMUTABLE STRICT PARALLEL SAFE; +CREATE FUNCTION angular_distance(vector, vector) RETURNS float8 + AS 'MODULE_PATHNAME' LANGUAGE C IMMUTABLE STRICT PARALLEL SAFE; + CREATE FUNCTION vector_spherical_distance(vector, vector) RETURNS float8 AS 'MODULE_PATHNAME' LANGUAGE C IMMUTABLE STRICT PARALLEL SAFE; diff --git a/src/vector.c b/src/vector.c index d3ebedb..99ef54d 100644 --- a/src/vector.c +++ b/src/vector.c @@ -684,6 +684,49 @@ cosine_distance(PG_FUNCTION_ARGS) PG_RETURN_FLOAT8(1.0 - similarity); } +/* + * Get the angular distance between two vectors + */ +PGDLLEXPORT PG_FUNCTION_INFO_V1(angular_distance); +Datum +angular_distance(PG_FUNCTION_ARGS) +{ + Vector *a = PG_GETARG_VECTOR_P(0); + Vector *b = PG_GETARG_VECTOR_P(1); + float *ax = a->x; + float *bx = b->x; + float distance = 0.0; + float norma = 0.0; + float normb = 0.0; + double similarity; + + CheckDims(a, b); + + /* Auto-vectorized */ + for (int i = 0; i < a->dim; i++) + { + distance += ax[i] * bx[i]; + norma += ax[i] * ax[i]; + normb += bx[i] * bx[i]; + } + + similarity = (double) distance / sqrt((double) norma * (double) normb); + +#ifdef _MSC_VER + /* /fp:fast may not propagate NaN */ + if (isnan(similarity)) + PG_RETURN_FLOAT8(NAN); +#endif + + /* Prevent NaN with acos with loss of precision */ + if (similarity > 1) + similarity = 1; + else if (similarity < -1) + similarity = -1; + + PG_RETURN_FLOAT8(acos(similarity) / M_PI); +} + /* * Get the distance for spherical k-means * Currently uses angular distance since needs to satisfy triangle inequality diff --git a/test/expected/functions.out b/test/expected/functions.out index 16092ba..648e80c 100644 --- a/test/expected/functions.out +++ b/test/expected/functions.out @@ -152,6 +152,50 @@ SELECT l1_distance('[3e38]', '[-3e38]'); Infinity (1 row) +SELECT angular_distance('[1,2]', '[2,4]'); + angular_distance +------------------ + 0 +(1 row) + +SELECT angular_distance('[1,2]', '[0,0]'); + angular_distance +------------------ + NaN +(1 row) + +SELECT angular_distance('[1,1]', '[1,1]'); + angular_distance +------------------ + 0 +(1 row) + +SELECT angular_distance('[1,1]', '[-1,-1]'); + angular_distance +------------------ + 1 +(1 row) + +SELECT angular_distance('[1,2]', '[3]'); +ERROR: different vector dimensions 2 and 1 +SELECT angular_distance('[1,1]', '[1.1,1.1]'); + angular_distance +------------------ + 0 +(1 row) + +SELECT angular_distance('[1,1]', '[-1.1,-1.1]'); + angular_distance +------------------ + 1 +(1 row) + +SELECT angular_distance('[3e38]', '[3e38]'); + angular_distance +------------------ + NaN +(1 row) + SELECT avg(v) FROM unnest(ARRAY['[1,2,3]'::vector, '[3,5,7]']) v; avg ----------- diff --git a/test/sql/functions.sql b/test/sql/functions.sql index 78ec8c1..b70f87b 100644 --- a/test/sql/functions.sql +++ b/test/sql/functions.sql @@ -36,6 +36,15 @@ SELECT l1_distance('[0,0]', '[0,1]'); SELECT l1_distance('[1,2]', '[3]'); SELECT l1_distance('[3e38]', '[-3e38]'); +SELECT angular_distance('[1,2]', '[2,4]'); +SELECT angular_distance('[1,2]', '[0,0]'); +SELECT angular_distance('[1,1]', '[1,1]'); +SELECT angular_distance('[1,1]', '[-1,-1]'); +SELECT angular_distance('[1,2]', '[3]'); +SELECT angular_distance('[1,1]', '[1.1,1.1]'); +SELECT angular_distance('[1,1]', '[-1.1,-1.1]'); +SELECT angular_distance('[3e38]', '[3e38]'); + SELECT avg(v) FROM unnest(ARRAY['[1,2,3]'::vector, '[3,5,7]']) v; SELECT avg(v) FROM unnest(ARRAY['[1,2,3]'::vector, '[3,5,7]', NULL]) v; SELECT avg(v) FROM unnest(ARRAY[]::vector[]) v;