mirror of
https://github.com/pgvector/pgvector.git
synced 2026-06-06 05:51:21 +08:00
Added l2_normalize function for sparsevec
This commit is contained in:
@@ -952,6 +952,7 @@ inner_product(sparsevec, sparsevec) → double precision | inner product | unrel
|
|||||||
l1_distance(sparsevec, sparsevec) → double precision | taxicab distance | unreleased
|
l1_distance(sparsevec, sparsevec) → double precision | taxicab distance | unreleased
|
||||||
l2_distance(sparsevec, sparsevec) → double precision | Euclidean distance | unreleased
|
l2_distance(sparsevec, sparsevec) → double precision | Euclidean distance | unreleased
|
||||||
l2_norm(sparsevec) → double precision | Euclidean norm | unreleased
|
l2_norm(sparsevec) → double precision | Euclidean norm | unreleased
|
||||||
|
l2_normalize(sparsevec) → sparsevec | Normalize with Euclidean norm | unreleased
|
||||||
|
|
||||||
## Installation Notes - Linux and Mac
|
## Installation Notes - Linux and Mac
|
||||||
|
|
||||||
|
|||||||
@@ -364,6 +364,9 @@ CREATE FUNCTION l1_distance(sparsevec, sparsevec) RETURNS float8
|
|||||||
CREATE FUNCTION l2_norm(sparsevec) RETURNS float8
|
CREATE FUNCTION l2_norm(sparsevec) RETURNS float8
|
||||||
AS 'MODULE_PATHNAME', 'sparsevec_l2_norm' LANGUAGE C IMMUTABLE STRICT PARALLEL SAFE;
|
AS 'MODULE_PATHNAME', 'sparsevec_l2_norm' LANGUAGE C IMMUTABLE STRICT PARALLEL SAFE;
|
||||||
|
|
||||||
|
CREATE FUNCTION l2_normalize(sparsevec) RETURNS sparsevec
|
||||||
|
AS 'MODULE_PATHNAME', 'sparsevec_l2_normalize' LANGUAGE C IMMUTABLE STRICT PARALLEL SAFE;
|
||||||
|
|
||||||
CREATE FUNCTION sparsevec_lt(sparsevec, sparsevec) RETURNS bool
|
CREATE FUNCTION sparsevec_lt(sparsevec, sparsevec) RETURNS bool
|
||||||
AS 'MODULE_PATHNAME' LANGUAGE C IMMUTABLE STRICT PARALLEL SAFE;
|
AS 'MODULE_PATHNAME' LANGUAGE C IMMUTABLE STRICT PARALLEL SAFE;
|
||||||
|
|
||||||
|
|||||||
@@ -673,6 +673,9 @@ CREATE FUNCTION l1_distance(sparsevec, sparsevec) RETURNS float8
|
|||||||
CREATE FUNCTION l2_norm(sparsevec) RETURNS float8
|
CREATE FUNCTION l2_norm(sparsevec) RETURNS float8
|
||||||
AS 'MODULE_PATHNAME', 'sparsevec_l2_norm' LANGUAGE C IMMUTABLE STRICT PARALLEL SAFE;
|
AS 'MODULE_PATHNAME', 'sparsevec_l2_norm' LANGUAGE C IMMUTABLE STRICT PARALLEL SAFE;
|
||||||
|
|
||||||
|
CREATE FUNCTION l2_normalize(sparsevec) RETURNS sparsevec
|
||||||
|
AS 'MODULE_PATHNAME', 'sparsevec_l2_normalize' LANGUAGE C IMMUTABLE STRICT PARALLEL SAFE;
|
||||||
|
|
||||||
-- sparsevec private functions
|
-- sparsevec private functions
|
||||||
|
|
||||||
CREATE FUNCTION sparsevec_lt(sparsevec, sparsevec) RETURNS bool
|
CREATE FUNCTION sparsevec_lt(sparsevec, sparsevec) RETURNS bool
|
||||||
|
|||||||
@@ -213,20 +213,7 @@ HnswNormValue(FmgrInfo *procinfo, Oid collation, Datum *value, HnswType type)
|
|||||||
else if (type == HNSW_TYPE_HALFVEC)
|
else if (type == HNSW_TYPE_HALFVEC)
|
||||||
*value = DirectFunctionCall1(halfvec_l2_normalize, *value);
|
*value = DirectFunctionCall1(halfvec_l2_normalize, *value);
|
||||||
else if (type == HNSW_TYPE_SPARSEVEC)
|
else if (type == HNSW_TYPE_SPARSEVEC)
|
||||||
{
|
*value = DirectFunctionCall1(sparsevec_l2_normalize, *value);
|
||||||
SparseVector *v = DatumGetSparseVector(*value);
|
|
||||||
SparseVector *result = InitSparseVector(v->dim, v->nnz);
|
|
||||||
float *vx = SPARSEVEC_VALUES(v);
|
|
||||||
float *rx = SPARSEVEC_VALUES(result);
|
|
||||||
|
|
||||||
for (int i = 0; i < v->nnz; i++)
|
|
||||||
{
|
|
||||||
result->indices[i] = v->indices[i];
|
|
||||||
rx[i] = vx[i] / norm;
|
|
||||||
}
|
|
||||||
|
|
||||||
*value = PointerGetDatum(result);
|
|
||||||
}
|
|
||||||
else
|
else
|
||||||
elog(ERROR, "Unsupported type");
|
elog(ERROR, "Unsupported type");
|
||||||
|
|
||||||
|
|||||||
@@ -848,6 +848,48 @@ sparsevec_l2_norm(PG_FUNCTION_ARGS)
|
|||||||
PG_RETURN_FLOAT8(sqrt(norm));
|
PG_RETURN_FLOAT8(sqrt(norm));
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/*
|
||||||
|
* Normalize a sparse vector with the L2 norm
|
||||||
|
*/
|
||||||
|
PGDLLEXPORT PG_FUNCTION_INFO_V1(sparsevec_l2_normalize);
|
||||||
|
Datum
|
||||||
|
sparsevec_l2_normalize(PG_FUNCTION_ARGS)
|
||||||
|
{
|
||||||
|
SparseVector *a = PG_GETARG_SPARSEVEC_P(0);
|
||||||
|
float *ax = SPARSEVEC_VALUES(a);
|
||||||
|
double norm = 0;
|
||||||
|
SparseVector *result;
|
||||||
|
float *rx;
|
||||||
|
|
||||||
|
result = InitSparseVector(a->dim, a->nnz);
|
||||||
|
rx = SPARSEVEC_VALUES(result);
|
||||||
|
|
||||||
|
/* Auto-vectorized */
|
||||||
|
for (int i = 0; i < a->nnz; i++)
|
||||||
|
norm += (double) ax[i] * (double) ax[i];
|
||||||
|
|
||||||
|
norm = sqrt(norm);
|
||||||
|
|
||||||
|
/* Return zero vector for zero norm */
|
||||||
|
if (norm > 0)
|
||||||
|
{
|
||||||
|
for (int i = 0; i < a->nnz; i++)
|
||||||
|
{
|
||||||
|
result->indices[i] = a->indices[i];
|
||||||
|
rx[i] = ax[i] / norm;
|
||||||
|
}
|
||||||
|
|
||||||
|
/* Check for overflow */
|
||||||
|
for (int i = 0; i < a->nnz; i++)
|
||||||
|
{
|
||||||
|
if (isinf(rx[i]))
|
||||||
|
float_overflow_error();
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
PG_RETURN_POINTER(result);
|
||||||
|
}
|
||||||
|
|
||||||
/*
|
/*
|
||||||
* Internal helper to compare sparse vectors
|
* Internal helper to compare sparse vectors
|
||||||
*/
|
*/
|
||||||
|
|||||||
@@ -1,6 +1,8 @@
|
|||||||
#ifndef SPARSEVEC_H
|
#ifndef SPARSEVEC_H
|
||||||
#define SPARSEVEC_H
|
#define SPARSEVEC_H
|
||||||
|
|
||||||
|
#include "fmgr.h"
|
||||||
|
|
||||||
#define SPARSEVEC_MAX_DIM 100000
|
#define SPARSEVEC_MAX_DIM 100000
|
||||||
#define SPARSEVEC_MAX_NNZ 16000
|
#define SPARSEVEC_MAX_NNZ 16000
|
||||||
|
|
||||||
@@ -21,5 +23,6 @@ typedef struct SparseVector
|
|||||||
} SparseVector;
|
} SparseVector;
|
||||||
|
|
||||||
SparseVector *InitSparseVector(int dim, int nnz);
|
SparseVector *InitSparseVector(int dim, int nnz);
|
||||||
|
Datum sparsevec_l2_normalize(PG_FUNCTION_ARGS);
|
||||||
|
|
||||||
#endif
|
#endif
|
||||||
|
|||||||
@@ -292,3 +292,33 @@ SELECT l1_distance('{1:1,3:3,5:5,7:7,9:9}/9'::sparsevec, '{2:2,4:4,6:6,8:8}/9');
|
|||||||
45
|
45
|
||||||
(1 row)
|
(1 row)
|
||||||
|
|
||||||
|
SELECT l2_normalize('{1:3,2:4}/2'::sparsevec);
|
||||||
|
l2_normalize
|
||||||
|
-----------------
|
||||||
|
{1:0.6,2:0.8}/2
|
||||||
|
(1 row)
|
||||||
|
|
||||||
|
SELECT l2_normalize('{1:3}/2'::sparsevec);
|
||||||
|
l2_normalize
|
||||||
|
--------------
|
||||||
|
{1:1}/2
|
||||||
|
(1 row)
|
||||||
|
|
||||||
|
SELECT l2_normalize('{2:0.1}/2'::sparsevec);
|
||||||
|
l2_normalize
|
||||||
|
--------------
|
||||||
|
{2:1}/2
|
||||||
|
(1 row)
|
||||||
|
|
||||||
|
SELECT l2_normalize('{}/2'::sparsevec);
|
||||||
|
l2_normalize
|
||||||
|
--------------
|
||||||
|
{}/2
|
||||||
|
(1 row)
|
||||||
|
|
||||||
|
SELECT l2_normalize('{1:3e38}/1'::sparsevec);
|
||||||
|
l2_normalize
|
||||||
|
--------------
|
||||||
|
{1:1}/1
|
||||||
|
(1 row)
|
||||||
|
|
||||||
|
|||||||
@@ -55,3 +55,9 @@ SELECT l1_distance('{1:1,2:2}/2'::sparsevec, '{1:3}/1');
|
|||||||
SELECT l1_distance('{1:3e38}/1'::sparsevec, '{1:-3e38}/1');
|
SELECT l1_distance('{1:3e38}/1'::sparsevec, '{1:-3e38}/1');
|
||||||
SELECT l1_distance('{1:1,3:3,5:5,7:7}/8'::sparsevec, '{2:2,4:4,6:6,8:8}/8');
|
SELECT l1_distance('{1:1,3:3,5:5,7:7}/8'::sparsevec, '{2:2,4:4,6:6,8:8}/8');
|
||||||
SELECT l1_distance('{1:1,3:3,5:5,7:7,9:9}/9'::sparsevec, '{2:2,4:4,6:6,8:8}/9');
|
SELECT l1_distance('{1:1,3:3,5:5,7:7,9:9}/9'::sparsevec, '{2:2,4:4,6:6,8:8}/9');
|
||||||
|
|
||||||
|
SELECT l2_normalize('{1:3,2:4}/2'::sparsevec);
|
||||||
|
SELECT l2_normalize('{1:3}/2'::sparsevec);
|
||||||
|
SELECT l2_normalize('{2:0.1}/2'::sparsevec);
|
||||||
|
SELECT l2_normalize('{}/2'::sparsevec);
|
||||||
|
SELECT l2_normalize('{1:3e38}/1'::sparsevec);
|
||||||
|
|||||||
Reference in New Issue
Block a user