diff --git a/CHANGELOG.md b/CHANGELOG.md index 67fc031..42eb112 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -5,6 +5,7 @@ If upgrading with Postgres < 13, see [this note](https://github.com/pgvector/pgv - Changed storage for vector from `extended` to `external` - Improved performance of HNSW - Added support for parallel index builds for HNSW +- Added `hamming_distance` function - Added validation for GUC parameters - Reduced memory usage for HNSW index builds - Reduced WAL generation for HNSW index builds diff --git a/sql/vector--0.5.1--0.6.0.sql b/sql/vector--0.5.1--0.6.0.sql index 8e5af7f..9c5e6a7 100644 --- a/sql/vector--0.5.1--0.6.0.sql +++ b/sql/vector--0.5.1--0.6.0.sql @@ -3,3 +3,6 @@ -- remove this single line for Postgres < 13 ALTER TYPE vector SET (STORAGE = external); + +CREATE FUNCTION hamming_distance(bytea, bytea) RETURNS float8 + AS 'MODULE_PATHNAME' LANGUAGE C IMMUTABLE STRICT PARALLEL SAFE; diff --git a/sql/vector.sql b/sql/vector.sql index 4b17faa..a535720 100644 --- a/sql/vector.sql +++ b/sql/vector.sql @@ -290,3 +290,8 @@ CREATE OPERATOR CLASS vector_cosine_ops OPERATOR 1 <=> (vector, vector) FOR ORDER BY float_ops, FUNCTION 1 vector_negative_inner_product(vector, vector), FUNCTION 2 vector_norm(vector); + +-- bytea functions + +CREATE FUNCTION hamming_distance(bytea, bytea) RETURNS float8 + AS 'MODULE_PATHNAME' LANGUAGE C IMMUTABLE STRICT PARALLEL SAFE; diff --git a/src/vector.c b/src/vector.c index 982b176..562373d 100644 --- a/src/vector.c +++ b/src/vector.c @@ -10,6 +10,7 @@ #include "lib/stringinfo.h" #include "libpq/pqformat.h" #include "port.h" /* for strtof() */ +#include "port/pg_bitutils.h" #include "utils/array.h" #include "utils/builtins.h" #include "utils/float.h" @@ -1134,3 +1135,43 @@ vector_avg(PG_FUNCTION_ARGS) PG_RETURN_POINTER(result); } + +/* + * Ensure same number of bytes; reports an ERROR with ERRCODE_DATA_EXCEPTION when aLen and bLen differ + */ +static inline void +CheckByteLengths(unsigned long aLen, unsigned long bLen) +{ + if (aLen != bLen) + ereport(ERROR,
(errcode(ERRCODE_DATA_EXCEPTION), + errmsg("different byte lengths %lu and %lu", aLen, bLen))); +} + +/* + * Get the hamming distance between two binary strings: each byte pair is XORed and the set bits of the result are summed, so the value is the number of differing bits; errors if the byte lengths differ + */ +PGDLLEXPORT PG_FUNCTION_INFO_V1(hamming_distance); +Datum +hamming_distance(PG_FUNCTION_ARGS) +{ + bytea *a = PG_GETARG_BYTEA_PP(0); + bytea *b = PG_GETARG_BYTEA_PP(1); + char *ax = VARDATA_ANY(a); + char *bx = VARDATA_ANY(b); + unsigned long aLen = VARSIZE_ANY_EXHDR(a); + unsigned long bLen = VARSIZE_ANY_EXHDR(b); + uint64 distance = 0; + + CheckByteLengths(aLen, bLen); + + for (unsigned long i = 0; i < aLen; i++) + { + unsigned char diff = (unsigned char) (ax[i] ^ bx[i]); + + distance += pg_number_of_ones[diff]; + } + + /* TODO Decide on return type (the uint64 bit count is currently returned as float8, matching the SQL declaration) */ + PG_RETURN_FLOAT8((double) distance); +} diff --git a/test/expected/functions.out b/test/expected/functions.out index 2840688..84da2d8 100644 --- a/test/expected/functions.out +++ b/test/expected/functions.out @@ -158,6 +158,32 @@ SELECT l1_distance('[3e38]', '[-3e38]'); Infinity (1 row) +SELECT hamming_distance('\xFF', '\xFF'); + hamming_distance +------------------ + 0 +(1 row) + +SELECT hamming_distance('\xFF', '\xFE'); + hamming_distance +------------------ + 1 +(1 row) + +SELECT hamming_distance('\xFF', '\xFC'); + hamming_distance +------------------ + 2 +(1 row) + +SELECT hamming_distance('\xFF', '\x00'); + hamming_distance +------------------ + 8 +(1 row) + +SELECT hamming_distance('\xFF', '\x0000'); +ERROR: different byte lengths 1 and 2 SELECT avg(v) FROM unnest(ARRAY['[1,2,3]'::vector, '[3,5,7]']) v; avg ----------- diff --git a/test/sql/functions.sql b/test/sql/functions.sql index 914df36..fc657e2 100644 --- a/test/sql/functions.sql +++ b/test/sql/functions.sql @@ -37,6 +37,12 @@ SELECT l1_distance('[0,0]', '[0,1]'); SELECT l1_distance('[1,2]', '[3]'); SELECT l1_distance('[3e38]', '[-3e38]'); +SELECT hamming_distance('\xFF', '\xFF'); +SELECT hamming_distance('\xFF', '\xFE'); +SELECT hamming_distance('\xFF', '\xFC'); +SELECT
hamming_distance('\xFF', '\x00'); +SELECT hamming_distance('\xFF', '\x0000'); + SELECT avg(v) FROM unnest(ARRAY['[1,2,3]'::vector, '[3,5,7]']) v; SELECT avg(v) FROM unnest(ARRAY['[1,2,3]'::vector, '[3,5,7]', NULL]) v; SELECT avg(v) FROM unnest(ARRAY[]::vector[]) v;