From f3aec9fd03e5375d039eb95fe8bb0d3792ed13bd Mon Sep 17 00:00:00 2001 From: Andrew Kane Date: Mon, 25 Mar 2024 14:22:23 -0700 Subject: [PATCH] Added hamming_distance function --- CHANGELOG.md | 1 + sql/vector--0.6.2--0.6.3.sql | 5 +++++ sql/vector.sql | 5 +++++ src/vector.c | 36 ++++++++++++++++++++++++++++++++++++ test/expected/functions.out | 26 ++++++++++++++++++++++++++ test/sql/functions.sql | 6 ++++++ 6 files changed, 79 insertions(+) create mode 100644 sql/vector--0.6.2--0.6.3.sql diff --git a/CHANGELOG.md b/CHANGELOG.md index 9bcea19..f4022a2 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -13,6 +13,7 @@ If upgrading with Postgres 12 or Docker, see [these notes](https://github.com/pgvector/pgvector#060). - Added support for parallel index builds for HNSW +- Added `hamming_distance` function - Added validation for GUC parameters - Changed storage for vector from `extended` to `external` - Improved performance of HNSW diff --git a/sql/vector--0.6.2--0.6.3.sql b/sql/vector--0.6.2--0.6.3.sql new file mode 100644 index 0000000..5fcab73 --- /dev/null +++ b/sql/vector--0.6.2--0.6.3.sql @@ -0,0 +1,5 @@ +-- complain if script is sourced in psql, rather than via CREATE EXTENSION +\echo Use "ALTER EXTENSION vector UPDATE TO '0.6.3'" to load this file. \quit + +CREATE FUNCTION hamming_distance(bit, bit) RETURNS float8 + AS 'MODULE_PATHNAME' LANGUAGE C IMMUTABLE STRICT PARALLEL SAFE; diff --git a/sql/vector.sql b/sql/vector.sql index 141e83c..eb32566 100644 --- a/sql/vector.sql +++ b/sql/vector.sql @@ -287,3 +287,8 @@ CREATE OPERATOR CLASS vector_cosine_ops OPERATOR 1 <=> (vector, vector) FOR ORDER BY float_ops, FUNCTION 1 vector_negative_inner_product(vector, vector), FUNCTION 2 vector_norm(vector); + +-- bit functions + +CREATE FUNCTION hamming_distance(bit, bit) RETURNS float8 + AS 'MODULE_PATHNAME' LANGUAGE C IMMUTABLE STRICT PARALLEL SAFE; diff --git a/src/vector.c b/src/vector.c index 5f3cbbb..73231ff 100644 --- a/src/vector.c +++ b/src/vector.c @@ -10,11 +10,13 @@ #include "lib/stringinfo.h" #include "libpq/pqformat.h" #include "port.h" /* for strtof() */ +#include "port/pg_bitutils.h" #include "utils/array.h" #include "utils/builtins.h" #include "utils/float.h" #include "utils/lsyscache.h" #include "utils/numeric.h" +#include "utils/varbit.h" #include "vector.h" #if PG_VERSION_NUM >= 160000 @@ -1160,3 +1162,37 @@ vector_avg(PG_FUNCTION_ARGS) PG_RETURN_POINTER(result); } + +/* + * Ensure same number of bits + */ +static inline void +CheckBitLengths(uint32 aLen, uint32 bLen) +{ + if (aLen != bLen) + ereport(ERROR, + (errcode(ERRCODE_DATA_EXCEPTION), + errmsg("different bit lengths %u and %u", aLen, bLen))); +} + +/* + * Get the Hamming distance between two bit strings + */ +PGDLLEXPORT PG_FUNCTION_INFO_V1(hamming_distance); +Datum +hamming_distance(PG_FUNCTION_ARGS) +{ + VarBit *a = PG_GETARG_VARBIT_P(0); + VarBit *b = PG_GETARG_VARBIT_P(1); + unsigned char *ax = VARBITS(a); + unsigned char *bx = VARBITS(b); + uint64 distance = 0; + + CheckBitLengths(VARBITLEN(a), VARBITLEN(b)); + + for (int i = 0; i < VARBITBYTES(a); i++) + distance += pg_number_of_ones[ax[i] ^ bx[i]]; + + /* TODO Decide on return type */ + PG_RETURN_FLOAT8((double) distance); +} diff --git a/test/expected/functions.out b/test/expected/functions.out index 85d1a2f..2b5e740 100644 --- a/test/expected/functions.out +++ b/test/expected/functions.out @@ -208,6 +208,32 @@ SELECT l1_distance('[3e38]', '[-3e38]'); Infinity (1 row) +SELECT hamming_distance(B'111', B'111'); + hamming_distance +------------------ + 0 +(1 row) + +SELECT hamming_distance(B'111', B'110'); + hamming_distance +------------------ + 1 +(1 row) + +SELECT hamming_distance(B'111', B'100'); + hamming_distance +------------------ + 2 +(1 row) + +SELECT hamming_distance(B'111', B'000'); + hamming_distance +------------------ + 3 +(1 row) + +SELECT hamming_distance(B'111', B'00'); +ERROR: different bit lengths 3 and 2 SELECT avg(v) FROM unnest(ARRAY['[1,2,3]'::vector, '[3,5,7]']) v; avg ----------- diff --git a/test/sql/functions.sql b/test/sql/functions.sql index 6235684..81fe752 100644 --- a/test/sql/functions.sql +++ b/test/sql/functions.sql @@ -48,6 +48,12 @@ SELECT l1_distance('[0,0]', '[0,1]'); SELECT l1_distance('[1,2]', '[3]'); SELECT l1_distance('[3e38]', '[-3e38]'); +SELECT hamming_distance(B'111', B'111'); +SELECT hamming_distance(B'111', B'110'); +SELECT hamming_distance(B'111', B'100'); +SELECT hamming_distance(B'111', B'000'); +SELECT hamming_distance(B'111', B'00'); + SELECT avg(v) FROM unnest(ARRAY['[1,2,3]'::vector, '[3,5,7]']) v; SELECT avg(v) FROM unnest(ARRAY['[1,2,3]'::vector, '[3,5,7]', NULL]) v; SELECT avg(v) FROM unnest(ARRAY[]::vector[]) v;