mirror of
https://github.com/pgvector/pgvector.git
synced 2026-06-06 05:51:21 +08:00
Exclude zero vectors for cosine distance to be consistent with other types [skip ci]
This commit is contained in:
@@ -39,6 +39,9 @@ CREATE FUNCTION cosine_distance(intvec, intvec) RETURNS float8
|
||||
CREATE FUNCTION l1_distance(intvec, intvec) RETURNS float8
|
||||
AS 'MODULE_PATHNAME', 'intvec_l1_distance' LANGUAGE C IMMUTABLE STRICT PARALLEL SAFE;
|
||||
|
||||
CREATE FUNCTION l2_norm(intvec) RETURNS float8
|
||||
AS 'MODULE_PATHNAME', 'intvec_l2_norm' LANGUAGE C IMMUTABLE STRICT PARALLEL SAFE;
|
||||
|
||||
CREATE FUNCTION intvec_l2_squared_distance(intvec, intvec) RETURNS float8
|
||||
AS 'MODULE_PATHNAME' LANGUAGE C IMMUTABLE STRICT PARALLEL SAFE;
|
||||
|
||||
@@ -85,4 +88,5 @@ CREATE OPERATOR CLASS intvec_ip_ops
|
||||
CREATE OPERATOR CLASS intvec_cosine_ops
|
||||
FOR TYPE intvec USING hnsw AS
|
||||
OPERATOR 1 <=> (intvec, intvec) FOR ORDER BY float_ops,
|
||||
FUNCTION 1 cosine_distance(intvec, intvec);
|
||||
FUNCTION 1 cosine_distance(intvec, intvec),
|
||||
FUNCTION 2 l2_norm(intvec);
|
||||
|
||||
@@ -330,6 +330,9 @@ CREATE FUNCTION cosine_distance(intvec, intvec) RETURNS float8
|
||||
CREATE FUNCTION l1_distance(intvec, intvec) RETURNS float8
|
||||
AS 'MODULE_PATHNAME', 'intvec_l1_distance' LANGUAGE C IMMUTABLE STRICT PARALLEL SAFE;
|
||||
|
||||
CREATE FUNCTION l2_norm(intvec) RETURNS float8
|
||||
AS 'MODULE_PATHNAME', 'intvec_l2_norm' LANGUAGE C IMMUTABLE STRICT PARALLEL SAFE;
|
||||
|
||||
-- intvec private functions
|
||||
|
||||
CREATE FUNCTION intvec_l2_squared_distance(intvec, intvec) RETURNS float8
|
||||
@@ -386,4 +389,5 @@ CREATE OPERATOR CLASS intvec_ip_ops
|
||||
CREATE OPERATOR CLASS intvec_cosine_ops
|
||||
FOR TYPE intvec USING hnsw AS
|
||||
OPERATOR 1 <=> (intvec, intvec) FOR ORDER BY float_ops,
|
||||
FUNCTION 1 cosine_distance(intvec, intvec);
|
||||
FUNCTION 1 cosine_distance(intvec, intvec),
|
||||
FUNCTION 2 l2_norm(intvec);
|
||||
|
||||
@@ -204,6 +204,10 @@ HnswNormValue(FmgrInfo *procinfo, Oid collation, Datum *value, HnswType type)
|
||||
|
||||
*value = PointerGetDatum(result);
|
||||
}
|
||||
else if (type == HNSW_TYPE_INTVEC)
|
||||
{
|
||||
/* Do nothing */
|
||||
}
|
||||
else
|
||||
elog(ERROR, "Unsupported type");
|
||||
|
||||
|
||||
18
src/intvec.c
18
src/intvec.c
@@ -588,3 +588,21 @@ intvec_l1_distance(PG_FUNCTION_ARGS)
|
||||
|
||||
PG_RETURN_FLOAT8((double) distance);
|
||||
}
|
||||
|
||||
/*
|
||||
* Get the L2 norm of an int vector
|
||||
*/
|
||||
PGDLLEXPORT PG_FUNCTION_INFO_V1(intvec_l2_norm);
|
||||
Datum
|
||||
intvec_l2_norm(PG_FUNCTION_ARGS)
|
||||
{
|
||||
IntVector *a = PG_GETARG_INTVEC_P(0);
|
||||
int8 *ax = a->x;
|
||||
int norm = 0;
|
||||
|
||||
/* Auto-vectorized */
|
||||
for (int i = 0; i < a->dim; i++)
|
||||
norm += ax[i] * ax[i];
|
||||
|
||||
PG_RETURN_FLOAT8(sqrt((double) norm));
|
||||
}
|
||||
|
||||
@@ -9,19 +9,18 @@ SELECT * FROM t ORDER BY val <=> '[3,3,3]';
|
||||
[1,1,1]
|
||||
[1,2,3]
|
||||
[1,2,4]
|
||||
[0,0,0]
|
||||
(4 rows)
|
||||
(3 rows)
|
||||
|
||||
SELECT COUNT(*) FROM (SELECT * FROM t ORDER BY val <=> '[0,0,0]') t2;
|
||||
count
|
||||
-------
|
||||
4
|
||||
3
|
||||
(1 row)
|
||||
|
||||
SELECT COUNT(*) FROM (SELECT * FROM t ORDER BY val <=> (SELECT NULL::intvec)) t2;
|
||||
count
|
||||
-------
|
||||
4
|
||||
3
|
||||
(1 row)
|
||||
|
||||
DROP TABLE t;
|
||||
|
||||
Reference in New Issue
Block a user