diff --git a/sql/vector--0.6.2--0.7.0.sql b/sql/vector--0.6.2--0.7.0.sql
index 3fdc8f1..26ac3ed 100644
--- a/sql/vector--0.6.2--0.7.0.sql
+++ b/sql/vector--0.6.2--0.7.0.sql
@@ -86,6 +86,9 @@ CREATE FUNCTION halfvec_l2_squared_distance(halfvec, halfvec) RETURNS float8
 CREATE FUNCTION halfvec_negative_inner_product(halfvec, halfvec) RETURNS float8
 	AS 'MODULE_PATHNAME' LANGUAGE C IMMUTABLE STRICT PARALLEL SAFE;
 
+CREATE FUNCTION halfvec_spherical_distance(halfvec, halfvec) RETURNS float8
+	AS 'MODULE_PATHNAME' LANGUAGE C IMMUTABLE STRICT PARALLEL SAFE;
+
 CREATE FUNCTION halfvec(halfvec, integer, boolean) RETURNS halfvec
 	AS 'MODULE_PATHNAME' LANGUAGE C IMMUTABLE STRICT PARALLEL SAFE;
 
diff --git a/sql/vector.sql b/sql/vector.sql
index c694b1c..8b396da 100644
--- a/sql/vector.sql
+++ b/sql/vector.sql
@@ -381,6 +381,9 @@ CREATE FUNCTION halfvec_l2_squared_distance(halfvec, halfvec) RETURNS float8
 CREATE FUNCTION halfvec_negative_inner_product(halfvec, halfvec) RETURNS float8
 	AS 'MODULE_PATHNAME' LANGUAGE C IMMUTABLE STRICT PARALLEL SAFE;
 
+CREATE FUNCTION halfvec_spherical_distance(halfvec, halfvec) RETURNS float8
+	AS 'MODULE_PATHNAME' LANGUAGE C IMMUTABLE STRICT PARALLEL SAFE;
+
 -- halfvec cast functions
 
 CREATE FUNCTION halfvec(halfvec, integer, boolean) RETURNS halfvec
diff --git a/src/halfvec.c b/src/halfvec.c
index 4faea8f..bdd86cc 100644
--- a/src/halfvec.c
+++ b/src/halfvec.c
@@ -879,6 +879,32 @@ halfvec_cosine_distance(PG_FUNCTION_ARGS)
 	PG_RETURN_FLOAT8(1 - similarity);
 }
 
+/*
+ * Get the distance for spherical k-means: acos(inner product) / pi
+ * Currently uses angular distance since it needs to satisfy the triangle inequality
+ * Assumes inputs are unit vectors (skips norm), so the inner product is used as the cosine
+ */
+PGDLLEXPORT PG_FUNCTION_INFO_V1(halfvec_spherical_distance);
+Datum
+halfvec_spherical_distance(PG_FUNCTION_ARGS)
+{
+	HalfVector *a = PG_GETARG_HALFVEC_P(0);
+	HalfVector *b = PG_GETARG_HALFVEC_P(1);
+	double		distance;
+
+	CheckDims(a, b);
+
+	distance = (double) HalfvecInnerProduct(a->dim, a->x, b->x);
+
+	/* Clamp to [-1, 1]: loss of precision can push the value outside the domain of acos, which would return NaN */
+	if (distance > 1)
+		distance = 1;
+	else if (distance < -1)
+		distance = -1;
+
+	PG_RETURN_FLOAT8(acos(distance) / M_PI);
+}
+
 /*
  * Get the L1 distance between two half vectors
  */