diff --git a/CHANGELOG.md b/CHANGELOG.md index 27c4bbf..4665abd 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -3,6 +3,7 @@ - Added HNSW index type - Added support for parallel index builds - Added `l1_distance` function +- Added `normalize_l2` function - Added element-wise multiplication for vectors - Added `sum` aggregate - Improved performance of distance functions diff --git a/README.md b/README.md index 4dc41c7..48b7118 100644 --- a/README.md +++ b/README.md @@ -392,6 +392,7 @@ cosine_distance(vector, vector) → double precision | cosine distance inner_product(vector, vector) → double precision | inner product l2_distance(vector, vector) → double precision | Euclidean distance l1_distance(vector, vector) → double precision | taxicab distance [unreleased] +normalize_l2(vector) → vector | normalize with Euclidean norm [unreleased] vector_dims(vector) → integer | number of dimensions vector_norm(vector) → double precision | Euclidean norm diff --git a/sql/vector--0.4.4--0.5.0.sql b/sql/vector--0.4.4--0.5.0.sql index 48572bf..c503e79 100644 --- a/sql/vector--0.4.4--0.5.0.sql +++ b/sql/vector--0.4.4--0.5.0.sql @@ -4,6 +4,9 @@ CREATE FUNCTION l1_distance(vector, vector) RETURNS float8 AS 'MODULE_PATHNAME' LANGUAGE C IMMUTABLE STRICT PARALLEL SAFE; +CREATE FUNCTION normalize_l2(vector) RETURNS vector + AS 'MODULE_PATHNAME' LANGUAGE C IMMUTABLE STRICT PARALLEL SAFE; + CREATE FUNCTION vector_mul(vector, vector) RETURNS vector AS 'MODULE_PATHNAME' LANGUAGE C IMMUTABLE STRICT PARALLEL SAFE; diff --git a/sql/vector.sql b/sql/vector.sql index 137931f..fb5fbdc 100644 --- a/sql/vector.sql +++ b/sql/vector.sql @@ -49,6 +49,9 @@ CREATE FUNCTION vector_dims(vector) RETURNS integer CREATE FUNCTION vector_norm(vector) RETURNS float8 AS 'MODULE_PATHNAME' LANGUAGE C IMMUTABLE STRICT PARALLEL SAFE; +CREATE FUNCTION normalize_l2(vector) RETURNS vector + AS 'MODULE_PATHNAME' LANGUAGE C IMMUTABLE STRICT PARALLEL SAFE; + CREATE FUNCTION vector_add(vector, vector) RETURNS vector AS 'MODULE_PATHNAME' LANGUAGE C IMMUTABLE STRICT PARALLEL SAFE; diff --git a/src/vector.c b/src/vector.c index 02964d6..8fada6b 100644 --- a/src/vector.c +++ b/src/vector.c @@ -745,6 +745,45 @@ vector_norm(PG_FUNCTION_ARGS) PG_RETURN_FLOAT8(sqrt((double) norm)); } +/* + * Normalize a vector with the L2 norm + */ +PGDLLEXPORT PG_FUNCTION_INFO_V1(normalize_l2); +Datum +normalize_l2(PG_FUNCTION_ARGS) +{ + Vector *a = PG_GETARG_VECTOR_P(0); + float *ax = a->x; + float norm = 0.0; + Vector *result; + float *rx; + + result = InitVector(a->dim); + rx = result->x; + + /* Auto-vectorized */ + for (int i = 0; i < a->dim; i++) + norm += ax[i] * ax[i]; + + norm = sqrtf(norm); + + if (norm > 0) + { + /* Auto-vectorized */ + for (int i = 0, imax = a->dim; i < imax; i++) + rx[i] = ax[i] / norm; + + /* Check for overflow */ + for (int i = 0, imax = a->dim; i < imax; i++) + { + if (isinf(rx[i])) + float_overflow_error(); + } + } + + PG_RETURN_POINTER(result); +} + /* * Add vectors */ diff --git a/test/expected/functions.out b/test/expected/functions.out index 6e83da0..aa014ba 100644 --- a/test/expected/functions.out +++ b/test/expected/functions.out @@ -48,6 +48,30 @@ SELECT vector_norm('[0,1]'); 1 (1 row) +SELECT normalize_l2('[3,4]'); + normalize_l2 +-------------- + [0.6,0.8] +(1 row) + +SELECT normalize_l2('[3,0]'); + normalize_l2 +-------------- + [1,0] +(1 row) + +SELECT normalize_l2('[0,0.1]'); + normalize_l2 +-------------- + [0,1] +(1 row) + +SELECT normalize_l2('[0,0]'); + normalize_l2 +-------------- + [0,0] +(1 row) + SELECT l2_distance('[0,0]', '[3,4]'); l2_distance ------------- diff --git a/test/sql/functions.sql b/test/sql/functions.sql index c71291a..56f896b 100644 --- a/test/sql/functions.sql +++ b/test/sql/functions.sql @@ -12,6 +12,11 @@ SELECT round(vector_norm('[1,1]')::numeric, 5); SELECT vector_norm('[3,4]'); SELECT vector_norm('[0,1]'); +SELECT normalize_l2('[3,4]'); +SELECT normalize_l2('[3,0]'); +SELECT normalize_l2('[0,0.1]'); +SELECT normalize_l2('[0,0]'); + SELECT l2_distance('[0,0]', '[3,4]'); SELECT l2_distance('[0,0]', '[0,1]'); SELECT l2_distance('[1,2]', '[3]');