From 127ecdd6506086fb416b01e90ab586a9a8fef125 Mon Sep 17 00:00:00 2001
From: Andrew Kane
Date: Mon, 15 Apr 2024 14:05:18 -0700
Subject: [PATCH] Added l2_normalize function for sparsevec

---
Reviewer notes (this area is ignored by git am):

- sparsevec_l2_normalize now leaves out elements whose quotient underflows
  to zero when narrowed to float, instead of returning them as explicit
  zero entries. A regression query covering that case was added.
- Hunk headers and the diffstat below were updated by hand for the amended
  files; the "index" blob hashes were not regenerated.
- Whitespace in context lines was reconstructed and should be confirmed
  against the tree before applying.

 README.md                             |  1 +
 sql/vector--0.6.2--0.7.0.sql          |  3 ++
 sql/vector.sql                        |  3 ++
 src/hnswutils.c                       | 15 +-----
 src/sparsevec.c                       | 72 +++++++++++++++++++++++++++
 src/sparsevec.h                       |  3 ++
 test/expected/sparsevec_functions.out | 36 ++++++++++++++
 test/sql/sparsevec_functions.sql      |  7 +++
 8 files changed, 126 insertions(+), 14 deletions(-)

diff --git a/README.md b/README.md
index 1699cfb..36908ad 100644
--- a/README.md
+++ b/README.md
@@ -952,6 +952,7 @@ inner_product(sparsevec, sparsevec) → double precision | inner product | unrel
 l1_distance(sparsevec, sparsevec) → double precision | taxicab distance | unreleased
 l2_distance(sparsevec, sparsevec) → double precision | Euclidean distance | unreleased
 l2_norm(sparsevec) → double precision | Euclidean norm | unreleased
+l2_normalize(sparsevec) → sparsevec | Normalize with Euclidean norm | unreleased
 
 ## Installation Notes - Linux and Mac
 
diff --git a/sql/vector--0.6.2--0.7.0.sql b/sql/vector--0.6.2--0.7.0.sql
index 16fe5ef..97b352f 100644
--- a/sql/vector--0.6.2--0.7.0.sql
+++ b/sql/vector--0.6.2--0.7.0.sql
@@ -364,6 +364,9 @@ CREATE FUNCTION l1_distance(sparsevec, sparsevec) RETURNS float8
 CREATE FUNCTION l2_norm(sparsevec) RETURNS float8
 	AS 'MODULE_PATHNAME', 'sparsevec_l2_norm' LANGUAGE C IMMUTABLE STRICT PARALLEL SAFE;
 
+CREATE FUNCTION l2_normalize(sparsevec) RETURNS sparsevec
+	AS 'MODULE_PATHNAME', 'sparsevec_l2_normalize' LANGUAGE C IMMUTABLE STRICT PARALLEL SAFE;
+
 CREATE FUNCTION sparsevec_lt(sparsevec, sparsevec) RETURNS bool
 	AS 'MODULE_PATHNAME' LANGUAGE C IMMUTABLE STRICT PARALLEL SAFE;
 
diff --git a/sql/vector.sql b/sql/vector.sql
index b31ea54..8202425 100644
--- a/sql/vector.sql
+++ b/sql/vector.sql
@@ -673,6 +673,9 @@ CREATE FUNCTION l1_distance(sparsevec, sparsevec) RETURNS float8
 CREATE FUNCTION l2_norm(sparsevec) RETURNS float8
 	AS 'MODULE_PATHNAME', 'sparsevec_l2_norm' LANGUAGE C IMMUTABLE STRICT PARALLEL SAFE;
 
+CREATE FUNCTION l2_normalize(sparsevec) RETURNS sparsevec
+	AS 'MODULE_PATHNAME', 'sparsevec_l2_normalize' LANGUAGE C IMMUTABLE STRICT PARALLEL SAFE;
+
 -- sparsevec private functions
 
 CREATE FUNCTION sparsevec_lt(sparsevec, sparsevec) RETURNS bool
diff --git a/src/hnswutils.c b/src/hnswutils.c
index 5071dfa..bd6dd5d 100644
--- a/src/hnswutils.c
+++ b/src/hnswutils.c
@@ -213,20 +213,7 @@ HnswNormValue(FmgrInfo *procinfo, Oid collation, Datum *value, HnswType type)
 		else if (type == HNSW_TYPE_HALFVEC)
 			*value = DirectFunctionCall1(halfvec_l2_normalize, *value);
 		else if (type == HNSW_TYPE_SPARSEVEC)
-		{
-			SparseVector *v = DatumGetSparseVector(*value);
-			SparseVector *result = InitSparseVector(v->dim, v->nnz);
-			float	   *vx = SPARSEVEC_VALUES(v);
-			float	   *rx = SPARSEVEC_VALUES(result);
-
-			for (int i = 0; i < v->nnz; i++)
-			{
-				result->indices[i] = v->indices[i];
-				rx[i] = vx[i] / norm;
-			}
-
-			*value = PointerGetDatum(result);
-		}
+			*value = DirectFunctionCall1(sparsevec_l2_normalize, *value);
 		else
 			elog(ERROR, "Unsupported type");
 
diff --git a/src/sparsevec.c b/src/sparsevec.c
index 073e031..6011d92 100644
--- a/src/sparsevec.c
+++ b/src/sparsevec.c
@@ -848,6 +848,78 @@ sparsevec_l2_norm(PG_FUNCTION_ARGS)
 	PG_RETURN_FLOAT8(sqrt(norm));
 }
 
+/*
+ * Normalize a sparse vector with the L2 norm
+ *
+ * Elements whose quotient underflows to zero are left out of the result,
+ * so that it does not carry explicit zero entries
+ */
+PGDLLEXPORT PG_FUNCTION_INFO_V1(sparsevec_l2_normalize);
+Datum
+sparsevec_l2_normalize(PG_FUNCTION_ARGS)
+{
+	SparseVector *a = PG_GETARG_SPARSEVEC_P(0);
+	float	   *ax = SPARSEVEC_VALUES(a);
+	double		norm = 0;
+	SparseVector *result;
+	float	   *rx;
+
+	result = InitSparseVector(a->dim, a->nnz);
+	rx = SPARSEVEC_VALUES(result);
+
+	/* Auto-vectorized */
+	for (int i = 0; i < a->nnz; i++)
+		norm += (double) ax[i] * (double) ax[i];
+
+	norm = sqrt(norm);
+
+	/* Return zero vector for zero norm */
+	if (norm > 0)
+	{
+		int			zeros = 0;
+
+		for (int i = 0; i < a->nnz; i++)
+		{
+			result->indices[i] = a->indices[i];
+			rx[i] = ax[i] / norm;
+		}
+
+		/* Check for overflow and count values lost to underflow */
+		for (int i = 0; i < a->nnz; i++)
+		{
+			if (isinf(rx[i]))
+				float_overflow_error();
+
+			if (rx[i] == 0)
+				zeros++;
+		}
+
+		/* Rebuild without the underflowed elements */
+		if (zeros > 0)
+		{
+			SparseVector *compacted = InitSparseVector(result->dim, result->nnz - zeros);
+			float	   *cx = SPARSEVEC_VALUES(compacted);
+			int			j = 0;
+
+			for (int i = 0; i < result->nnz; i++)
+			{
+				if (rx[i] == 0)
+					continue;
+
+				compacted->indices[j] = result->indices[i];
+				cx[j] = rx[i];
+				j++;
+			}
+
+			pfree(result);
+
+			PG_RETURN_POINTER(compacted);
+		}
+	}
+
+	PG_RETURN_POINTER(result);
+}
+
 /*
  * Internal helper to compare sparse vectors
  */
diff --git a/src/sparsevec.h b/src/sparsevec.h
index 1d79957..efba1bf 100644
--- a/src/sparsevec.h
+++ b/src/sparsevec.h
@@ -1,6 +1,8 @@
 #ifndef SPARSEVEC_H
 #define SPARSEVEC_H
 
+#include "fmgr.h"
+
 #define SPARSEVEC_MAX_DIM 100000
 #define SPARSEVEC_MAX_NNZ 16000
 
@@ -21,5 +23,6 @@ typedef struct SparseVector
 }			SparseVector;
 
 SparseVector *InitSparseVector(int dim, int nnz);
+Datum		sparsevec_l2_normalize(PG_FUNCTION_ARGS);
 
 #endif
diff --git a/test/expected/sparsevec_functions.out b/test/expected/sparsevec_functions.out
index 0b76d54..afa0250 100644
--- a/test/expected/sparsevec_functions.out
+++ b/test/expected/sparsevec_functions.out
@@ -292,3 +292,39 @@ SELECT l1_distance('{1:1,3:3,5:5,7:7,9:9}/9'::sparsevec, '{2:2,4:4,6:6,8:8}/9');
           45
 (1 row)
 
+SELECT l2_normalize('{1:3,2:4}/2'::sparsevec);
+  l2_normalize   
+-----------------
+ {1:0.6,2:0.8}/2
+(1 row)
+
+SELECT l2_normalize('{1:3}/2'::sparsevec);
+ l2_normalize 
+--------------
+ {1:1}/2
+(1 row)
+
+SELECT l2_normalize('{2:0.1}/2'::sparsevec);
+ l2_normalize 
+--------------
+ {2:1}/2
+(1 row)
+
+SELECT l2_normalize('{}/2'::sparsevec);
+ l2_normalize 
+--------------
+ {}/2
+(1 row)
+
+SELECT l2_normalize('{1:3e38}/1'::sparsevec);
+ l2_normalize 
+--------------
+ {1:1}/1
+(1 row)
+
+SELECT l2_normalize('{1:3e38,2:1e-37}/2'::sparsevec);
+ l2_normalize 
+--------------
+ {1:1}/2
+(1 row)
+
diff --git a/test/sql/sparsevec_functions.sql b/test/sql/sparsevec_functions.sql
index e65ac46..d260dbd 100644
--- a/test/sql/sparsevec_functions.sql
+++ b/test/sql/sparsevec_functions.sql
@@ -55,3 +55,10 @@ SELECT l1_distance('{1:1,2:2}/2'::sparsevec, '{1:3}/1');
 SELECT l1_distance('{1:3e38}/1'::sparsevec, '{1:-3e38}/1');
 SELECT l1_distance('{1:1,3:3,5:5,7:7}/8'::sparsevec, '{2:2,4:4,6:6,8:8}/8');
 SELECT l1_distance('{1:1,3:3,5:5,7:7,9:9}/9'::sparsevec, '{2:2,4:4,6:6,8:8}/9');
+
+SELECT l2_normalize('{1:3,2:4}/2'::sparsevec);
+SELECT l2_normalize('{1:3}/2'::sparsevec);
+SELECT l2_normalize('{2:0.1}/2'::sparsevec);
+SELECT l2_normalize('{}/2'::sparsevec);
+SELECT l2_normalize('{1:3e38}/1'::sparsevec);
+SELECT l2_normalize('{1:3e38,2:1e-37}/2'::sparsevec);