Add l2_norm(intvec) and list it in the intvec_cosine_ops HNSW operator class.

Summary of the hunks below:
  * README.md: document l2_norm(intvec) and move the l2_distance row after
    the l1_distance row in the intvec function table.
  * sql/vector--0.7.4--0.8.0.sql, sql/vector.sql: create l2_norm(intvec)
    (C symbol intvec_l2_norm) and add it as FUNCTION 2 of intvec_cosine_ops.
  * src/hnswutils.c: HnswNormValue returns the value unchanged when
    typeInfo->normalize is not set.
  * test/expected/hnsw_intvec.out: the [0,0,0] row is no longer returned by
    the cosine-ordered scans, so the expected counts drop from 4 to 3.

diff --git a/README.md b/README.md
index 7a234ff..7e0a4ca 100644
--- a/README.md
+++ b/README.md
@@ -953,8 +953,9 @@ Function | Description | Added
 --- | --- | ---
 cosine_distance(intvec, intvec) → double precision | cosine distance | 0.8.0
 inner_product(intvec, intvec) → double precision | inner product | 0.8.0
-l2_distance(intvec, intvec) → double precision | Euclidean distance | 0.8.0
 l1_distance(intvec, intvec) → double precision | taxicab distance | 0.8.0
+l2_distance(intvec, intvec) → double precision | Euclidean distance | 0.8.0
+l2_norm(intvec) → double precision | Euclidean norm | 0.8.0
 
 ### Bit Type
 
diff --git a/sql/vector--0.7.4--0.8.0.sql b/sql/vector--0.7.4--0.8.0.sql
index b490a94..1447d62 100644
--- a/sql/vector--0.7.4--0.8.0.sql
+++ b/sql/vector--0.7.4--0.8.0.sql
@@ -42,6 +42,9 @@ CREATE FUNCTION cosine_distance(intvec, intvec) RETURNS float8
 CREATE FUNCTION l1_distance(intvec, intvec) RETURNS float8
 	AS 'MODULE_PATHNAME', 'intvec_l1_distance' LANGUAGE C IMMUTABLE STRICT PARALLEL SAFE;
 
+CREATE FUNCTION l2_norm(intvec) RETURNS float8
+	AS 'MODULE_PATHNAME', 'intvec_l2_norm' LANGUAGE C IMMUTABLE STRICT PARALLEL SAFE;
+
 CREATE FUNCTION intvec_l2_squared_distance(intvec, intvec) RETURNS float8
 	AS 'MODULE_PATHNAME' LANGUAGE C IMMUTABLE STRICT PARALLEL SAFE;
 
@@ -96,6 +99,7 @@ CREATE OPERATOR CLASS intvec_cosine_ops
 	FOR TYPE intvec USING hnsw AS
 	OPERATOR 1 <=> (intvec, intvec) FOR ORDER BY float_ops,
 	FUNCTION 1 cosine_distance(intvec, intvec),
+	FUNCTION 2 l2_norm(intvec),
 	FUNCTION 3 hnsw_intvec_support(internal);
 
 CREATE OPERATOR CLASS intvec_l1_ops
diff --git a/sql/vector.sql b/sql/vector.sql
index 395418c..198f6e7 100644
--- a/sql/vector.sql
+++ b/sql/vector.sql
@@ -692,6 +692,9 @@ CREATE FUNCTION cosine_distance(intvec, intvec) RETURNS float8
 CREATE FUNCTION l1_distance(intvec, intvec) RETURNS float8
 	AS 'MODULE_PATHNAME', 'intvec_l1_distance' LANGUAGE C IMMUTABLE STRICT PARALLEL SAFE;
 
+CREATE FUNCTION l2_norm(intvec) RETURNS float8
+	AS 'MODULE_PATHNAME', 'intvec_l2_norm' LANGUAGE C IMMUTABLE STRICT PARALLEL SAFE;
+
 -- intvec private functions
 
 CREATE FUNCTION intvec_l2_squared_distance(intvec, intvec) RETURNS float8
@@ -756,6 +759,7 @@ CREATE OPERATOR CLASS intvec_cosine_ops
 	FOR TYPE intvec USING hnsw AS
 	OPERATOR 1 <=> (intvec, intvec) FOR ORDER BY float_ops,
 	FUNCTION 1 cosine_distance(intvec, intvec),
+	FUNCTION 2 l2_norm(intvec),
 	FUNCTION 3 hnsw_intvec_support(internal);
 
 CREATE OPERATOR CLASS intvec_l1_ops
diff --git a/src/hnswutils.c b/src/hnswutils.c
index 2b92a36..be162fa 100644
--- a/src/hnswutils.c
+++ b/src/hnswutils.c
@@ -159,6 +159,9 @@ HnswOptionalProcInfo(Relation index, uint16 procnum)
 Datum
 HnswNormValue(const HnswTypeInfo * typeInfo, Oid collation, Datum value)
 {
+	if (!typeInfo->normalize)
+		return value;
+
 	return DirectFunctionCall1Coll(typeInfo->normalize, collation, value);
 }
 
diff --git a/test/expected/hnsw_intvec.out b/test/expected/hnsw_intvec.out
index 53d6136..a25d206 100644
--- a/test/expected/hnsw_intvec.out
+++ b/test/expected/hnsw_intvec.out
@@ -64,19 +64,18 @@ SELECT * FROM t ORDER BY val <=> '[3,3,3]';
  [1,1,1]
  [1,2,3]
  [1,2,4]
- [0,0,0]
-(4 rows)
+(3 rows)
 
 SELECT COUNT(*) FROM (SELECT * FROM t ORDER BY val <=> '[0,0,0]') t2;
  count 
 -------
-     4
+     3
 (1 row)
 
 SELECT COUNT(*) FROM (SELECT * FROM t ORDER BY val <=> (SELECT NULL::intvec)) t2;
  count 
 -------
-     4
+     3
 (1 row)
 
 DROP TABLE t;