Added angular_distance function

This commit is contained in:
Andrew Kane
2023-09-01 19:45:59 -07:00
parent 0b0e542ce6
commit 1a0b9d81ce
6 changed files with 108 additions and 0 deletions

View File

@@ -1,3 +1,7 @@
## 0.5.1 (unreleased)
- Added `angular_distance` function
## 0.5.0 (2023-08-28)
- Added HNSW index type

View File

@@ -0,0 +1,5 @@
-- complain if script is sourced in psql, rather than via CREATE EXTENSION
\echo Use "ALTER EXTENSION vector UPDATE TO '0.5.1'" to load this file. \quit
CREATE FUNCTION angular_distance(vector, vector) RETURNS float8
AS 'MODULE_PATHNAME' LANGUAGE C IMMUTABLE STRICT PARALLEL SAFE;

View File

@@ -87,6 +87,9 @@ CREATE FUNCTION vector_l2_squared_distance(vector, vector) RETURNS float8
CREATE FUNCTION vector_negative_inner_product(vector, vector) RETURNS float8
AS 'MODULE_PATHNAME' LANGUAGE C IMMUTABLE STRICT PARALLEL SAFE;
CREATE FUNCTION angular_distance(vector, vector) RETURNS float8
AS 'MODULE_PATHNAME' LANGUAGE C IMMUTABLE STRICT PARALLEL SAFE;
CREATE FUNCTION vector_spherical_distance(vector, vector) RETURNS float8
AS 'MODULE_PATHNAME' LANGUAGE C IMMUTABLE STRICT PARALLEL SAFE;

View File

@@ -684,6 +684,49 @@ cosine_distance(PG_FUNCTION_ARGS)
PG_RETURN_FLOAT8(1.0 - similarity);
}
/*
* Get the angular distance between two vectors
*/
PGDLLEXPORT PG_FUNCTION_INFO_V1(angular_distance);
Datum
angular_distance(PG_FUNCTION_ARGS)
{
Vector *a = PG_GETARG_VECTOR_P(0);
Vector *b = PG_GETARG_VECTOR_P(1);
float *ax = a->x;
float *bx = b->x;
float distance = 0.0;
float norma = 0.0;
float normb = 0.0;
double similarity;
CheckDims(a, b);
/* Auto-vectorized */
for (int i = 0; i < a->dim; i++)
{
distance += ax[i] * bx[i];
norma += ax[i] * ax[i];
normb += bx[i] * bx[i];
}
similarity = (double) distance / sqrt((double) norma * (double) normb);
#ifdef _MSC_VER
/* /fp:fast may not propagate NaN */
if (isnan(similarity))
PG_RETURN_FLOAT8(NAN);
#endif
/* Prevent NaN with acos with loss of precision */
if (similarity > 1)
similarity = 1;
else if (similarity < -1)
similarity = -1;
PG_RETURN_FLOAT8(acos(similarity) / M_PI);
}
/*
* Get the distance for spherical k-means
* Currently uses angular distance since needs to satisfy triangle inequality

View File

@@ -152,6 +152,50 @@ SELECT l1_distance('[3e38]', '[-3e38]');
Infinity
(1 row)
SELECT angular_distance('[1,2]', '[2,4]');
angular_distance
------------------
0
(1 row)
SELECT angular_distance('[1,2]', '[0,0]');
angular_distance
------------------
NaN
(1 row)
SELECT angular_distance('[1,1]', '[1,1]');
angular_distance
------------------
0
(1 row)
SELECT angular_distance('[1,1]', '[-1,-1]');
angular_distance
------------------
1
(1 row)
SELECT angular_distance('[1,2]', '[3]');
ERROR: different vector dimensions 2 and 1
SELECT angular_distance('[1,1]', '[1.1,1.1]');
angular_distance
------------------
0
(1 row)
SELECT angular_distance('[1,1]', '[-1.1,-1.1]');
angular_distance
------------------
1
(1 row)
SELECT angular_distance('[3e38]', '[3e38]');
angular_distance
------------------
NaN
(1 row)
SELECT avg(v) FROM unnest(ARRAY['[1,2,3]'::vector, '[3,5,7]']) v;
avg
-----------

View File

@@ -36,6 +36,15 @@ SELECT l1_distance('[0,0]', '[0,1]');
SELECT l1_distance('[1,2]', '[3]');
SELECT l1_distance('[3e38]', '[-3e38]');
SELECT angular_distance('[1,2]', '[2,4]');
SELECT angular_distance('[1,2]', '[0,0]');
SELECT angular_distance('[1,1]', '[1,1]');
SELECT angular_distance('[1,1]', '[-1,-1]');
SELECT angular_distance('[1,2]', '[3]');
SELECT angular_distance('[1,1]', '[1.1,1.1]');
SELECT angular_distance('[1,1]', '[-1.1,-1.1]');
SELECT angular_distance('[3e38]', '[3e38]');
SELECT avg(v) FROM unnest(ARRAY['[1,2,3]'::vector, '[3,5,7]']) v;
SELECT avg(v) FROM unnest(ARRAY['[1,2,3]'::vector, '[3,5,7]', NULL]) v;
SELECT avg(v) FROM unnest(ARRAY[]::vector[]) v;