diff --git a/sql/vector--0.6.2--0.7.0.sql b/sql/vector--0.6.2--0.7.0.sql index 3b8e346..57e2bc5 100644 --- a/sql/vector--0.6.2--0.7.0.sql +++ b/sql/vector--0.6.2--0.7.0.sql @@ -39,6 +39,9 @@ CREATE FUNCTION cosine_distance(intvec, intvec) RETURNS float8 CREATE FUNCTION l1_distance(intvec, intvec) RETURNS float8 AS 'MODULE_PATHNAME', 'intvec_l1_distance' LANGUAGE C IMMUTABLE STRICT PARALLEL SAFE; +CREATE FUNCTION l2_norm(intvec) RETURNS float8 + AS 'MODULE_PATHNAME', 'intvec_l2_norm' LANGUAGE C IMMUTABLE STRICT PARALLEL SAFE; + CREATE FUNCTION intvec_l2_squared_distance(intvec, intvec) RETURNS float8 AS 'MODULE_PATHNAME' LANGUAGE C IMMUTABLE STRICT PARALLEL SAFE; @@ -85,4 +88,5 @@ CREATE OPERATOR CLASS intvec_ip_ops CREATE OPERATOR CLASS intvec_cosine_ops FOR TYPE intvec USING hnsw AS OPERATOR 1 <=> (intvec, intvec) FOR ORDER BY float_ops, - FUNCTION 1 cosine_distance(intvec, intvec); + FUNCTION 1 cosine_distance(intvec, intvec), + FUNCTION 2 l2_norm(intvec); diff --git a/sql/vector.sql b/sql/vector.sql index 05d97a2..ce7b744 100644 --- a/sql/vector.sql +++ b/sql/vector.sql @@ -330,6 +330,9 @@ CREATE FUNCTION cosine_distance(intvec, intvec) RETURNS float8 CREATE FUNCTION l1_distance(intvec, intvec) RETURNS float8 AS 'MODULE_PATHNAME', 'intvec_l1_distance' LANGUAGE C IMMUTABLE STRICT PARALLEL SAFE; +CREATE FUNCTION l2_norm(intvec) RETURNS float8 + AS 'MODULE_PATHNAME', 'intvec_l2_norm' LANGUAGE C IMMUTABLE STRICT PARALLEL SAFE; + -- intvec private functions CREATE FUNCTION intvec_l2_squared_distance(intvec, intvec) RETURNS float8 @@ -386,4 +389,5 @@ CREATE OPERATOR CLASS intvec_ip_ops CREATE OPERATOR CLASS intvec_cosine_ops FOR TYPE intvec USING hnsw AS OPERATOR 1 <=> (intvec, intvec) FOR ORDER BY float_ops, - FUNCTION 1 cosine_distance(intvec, intvec); + FUNCTION 1 cosine_distance(intvec, intvec), + FUNCTION 2 l2_norm(intvec); diff --git a/src/hnswutils.c b/src/hnswutils.c index 0104c81..d41ccad 100644 --- a/src/hnswutils.c +++ b/src/hnswutils.c @@ -204,6 +204,10 @@ HnswNormValue(FmgrInfo *procinfo, Oid collation, Datum *value, HnswType type) *value = PointerGetDatum(result); } + else if (type == HNSW_TYPE_INTVEC) + { + /* Do nothing */ + } else elog(ERROR, "Unsupported type"); diff --git a/src/intvec.c b/src/intvec.c index 879a369..4aa9e90 100644 --- a/src/intvec.c +++ b/src/intvec.c @@ -588,3 +588,21 @@ intvec_l1_distance(PG_FUNCTION_ARGS) PG_RETURN_FLOAT8((double) distance); } + +/* + * Get the L2 norm of an int vector + */ +PGDLLEXPORT PG_FUNCTION_INFO_V1(intvec_l2_norm); +Datum +intvec_l2_norm(PG_FUNCTION_ARGS) +{ + IntVector *a = PG_GETARG_INTVEC_P(0); + int8 *ax = a->x; + int norm = 0; + + /* Auto-vectorized */ + for (int i = 0; i < a->dim; i++) + norm += ax[i] * ax[i]; + + PG_RETURN_FLOAT8(sqrt((double) norm)); +} diff --git a/test/expected/hnsw_intvec_cosine.out b/test/expected/hnsw_intvec_cosine.out index 306e685..8a21c74 100644 --- a/test/expected/hnsw_intvec_cosine.out +++ b/test/expected/hnsw_intvec_cosine.out @@ -9,19 +9,18 @@ SELECT * FROM t ORDER BY val <=> '[3,3,3]'; [1,1,1] [1,2,3] [1,2,4] - [0,0,0] -(4 rows) +(3 rows) SELECT COUNT(*) FROM (SELECT * FROM t ORDER BY val <=> '[0,0,0]') t2; count ------- - 4 + 3 (1 row) SELECT COUNT(*) FROM (SELECT * FROM t ORDER BY val <=> (SELECT NULL::intvec)) t2; count ------- - 4 + 3 (1 row) DROP TABLE t;