Added normalize_l2 function

This commit is contained in:
Andrew Kane
2023-08-09 11:29:14 -07:00
parent 4b887a98ae
commit 47e361a93d
7 changed files with 76 additions and 0 deletions

View File

@@ -3,6 +3,7 @@
- Added HNSW index type
- Added support for parallel index builds
- Added `l1_distance` function
- Added `normalize_l2` function
- Added element-wise multiplication for vectors
- Added `sum` aggregate
- Improved performance of distance functions

View File

@@ -392,6 +392,7 @@ cosine_distance(vector, vector) → double precision | cosine distance
inner_product(vector, vector) → double precision | inner product
l2_distance(vector, vector) → double precision | Euclidean distance
l1_distance(vector, vector) → double precision | taxicab distance [unreleased]
normalize_l2(vector) → vector | normalize with Euclidean norm [unreleased]
vector_dims(vector) → integer | number of dimensions
vector_norm(vector) → double precision | Euclidean norm

View File

@@ -4,6 +4,9 @@
CREATE FUNCTION l1_distance(vector, vector) RETURNS float8
AS 'MODULE_PATHNAME' LANGUAGE C IMMUTABLE STRICT PARALLEL SAFE;
CREATE FUNCTION normalize_l2(vector) RETURNS vector
AS 'MODULE_PATHNAME' LANGUAGE C IMMUTABLE STRICT PARALLEL SAFE;
CREATE FUNCTION vector_mul(vector, vector) RETURNS vector
AS 'MODULE_PATHNAME' LANGUAGE C IMMUTABLE STRICT PARALLEL SAFE;

View File

@@ -49,6 +49,9 @@ CREATE FUNCTION vector_dims(vector) RETURNS integer
CREATE FUNCTION vector_norm(vector) RETURNS float8
AS 'MODULE_PATHNAME' LANGUAGE C IMMUTABLE STRICT PARALLEL SAFE;
CREATE FUNCTION normalize_l2(vector) RETURNS vector
AS 'MODULE_PATHNAME' LANGUAGE C IMMUTABLE STRICT PARALLEL SAFE;
CREATE FUNCTION vector_add(vector, vector) RETURNS vector
AS 'MODULE_PATHNAME' LANGUAGE C IMMUTABLE STRICT PARALLEL SAFE;

View File

@@ -745,6 +745,45 @@ vector_norm(PG_FUNCTION_ARGS)
PG_RETURN_FLOAT8(sqrt((double) norm));
}
/*
* Normalize a vector with the L2 norm
*/
PGDLLEXPORT PG_FUNCTION_INFO_V1(normalize_l2);
Datum
normalize_l2(PG_FUNCTION_ARGS)
{
Vector *a = PG_GETARG_VECTOR_P(0);
float *ax = a->x;
float norm = 0.0;
Vector *result;
float *rx;
result = InitVector(a->dim);
rx = result->x;
/* Auto-vectorized */
for (int i = 0; i < a->dim; i++)
norm += ax[i] * ax[i];
norm = sqrtf(norm);
if (norm > 0)
{
/* Auto-vectorized */
for (int i = 0, imax = a->dim; i < imax; i++)
rx[i] = ax[i] / norm;
/* Check for overflow */
for (int i = 0, imax = a->dim; i < imax; i++)
{
if (isinf(rx[i]))
float_overflow_error();
}
}
PG_RETURN_POINTER(result);
}
/*
* Add vectors
*/

View File

@@ -48,6 +48,30 @@ SELECT vector_norm('[0,1]');
1
(1 row)
SELECT normalize_l2('[3,4]');
normalize_l2
--------------
[0.6,0.8]
(1 row)
SELECT normalize_l2('[3,0]');
normalize_l2
--------------
[1,0]
(1 row)
SELECT normalize_l2('[0,0.1]');
normalize_l2
--------------
[0,1]
(1 row)
SELECT normalize_l2('[0,0]');
normalize_l2
--------------
[0,0]
(1 row)
SELECT l2_distance('[0,0]', '[3,4]');
l2_distance
-------------

View File

@@ -12,6 +12,11 @@ SELECT round(vector_norm('[1,1]')::numeric, 5);
SELECT vector_norm('[3,4]');
SELECT vector_norm('[0,1]');
SELECT normalize_l2('[3,4]');
SELECT normalize_l2('[3,0]');
SELECT normalize_l2('[0,0.1]');
SELECT normalize_l2('[0,0]');
SELECT l2_distance('[0,0]', '[3,4]');
SELECT l2_distance('[0,0]', '[0,1]');
SELECT l2_distance('[1,2]', '[3]');