diff --git a/CHANGELOG.md b/CHANGELOG.md index 68f274a..69066d9 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -2,6 +2,7 @@ - Added support for binary vectors to HNSW - Added `hamming_distance` function +- Added `jaccard_distance` function - Added `quantize_binary` function ## 0.6.2 (2024-03-18) diff --git a/README.md b/README.md index 73f7b91..1364297 100644 --- a/README.md +++ b/README.md @@ -728,6 +728,7 @@ Operator | Description | Added Function | Description | Added --- | --- | --- hamming_distance(bit, bit) → double precision | Hamming distance | 0.7.0 +jaccard_distance(bit, bit) → double precision | Jaccard distance | 0.7.0 ## Installation Notes - Linux and Mac diff --git a/sql/vector--0.6.2--0.7.0.sql b/sql/vector--0.6.2--0.7.0.sql index d385eaf..ba4504d 100644 --- a/sql/vector--0.6.2--0.7.0.sql +++ b/sql/vector--0.6.2--0.7.0.sql @@ -7,6 +7,9 @@ CREATE FUNCTION quantize_binary(vector) RETURNS bit CREATE FUNCTION hamming_distance(bit, bit) RETURNS float8 AS 'MODULE_PATHNAME' LANGUAGE C IMMUTABLE STRICT PARALLEL SAFE; +CREATE FUNCTION jaccard_distance(bit, bit) RETURNS float8 + AS 'MODULE_PATHNAME' LANGUAGE C IMMUTABLE STRICT PARALLEL SAFE; + CREATE OPERATOR <~> ( LEFTARG = bit, RIGHTARG = bit, PROCEDURE = hamming_distance, COMMUTATOR = '<~>' diff --git a/sql/vector.sql b/sql/vector.sql index f37e93e..6edea66 100644 --- a/sql/vector.sql +++ b/sql/vector.sql @@ -296,6 +296,9 @@ CREATE OPERATOR CLASS vector_cosine_ops CREATE FUNCTION hamming_distance(bit, bit) RETURNS float8 AS 'MODULE_PATHNAME' LANGUAGE C IMMUTABLE STRICT PARALLEL SAFE; +CREATE FUNCTION jaccard_distance(bit, bit) RETURNS float8 + AS 'MODULE_PATHNAME' LANGUAGE C IMMUTABLE STRICT PARALLEL SAFE; + CREATE OPERATOR <~> ( LEFTARG = bit, RIGHTARG = bit, PROCEDURE = hamming_distance, COMMUTATOR = '<~>' diff --git a/src/bitvector.c b/src/bitvector.c index 113d830..2c54d20 100644 --- a/src/bitvector.c +++ b/src/bitvector.c @@ -58,3 +58,32 @@ hamming_distance(PG_FUNCTION_ARGS) PG_RETURN_FLOAT8((double) distance); } + +/* + * Get the Jaccard distance between two bit strings + */ +PGDLLEXPORT PG_FUNCTION_INFO_V1(jaccard_distance); +Datum +jaccard_distance(PG_FUNCTION_ARGS) +{ + VarBit *a = PG_GETARG_VARBIT_P(0); + VarBit *b = PG_GETARG_VARBIT_P(1); + unsigned char *ax = VARBITS(a); + unsigned char *bx = VARBITS(b); + uint64 aa; + uint64 bb; + uint64 ab = 0; + + CheckBitLengths(VARBITLEN(a), VARBITLEN(b)); + + /* TODO Improve performance */ + aa = pg_popcount((char *) ax, VARBITBYTES(a)); + bb = pg_popcount((char *) bx, VARBITBYTES(b)); + for (uint32 i = 0; i < VARBITBYTES(a); i++) + ab += pg_number_of_ones[ax[i] & bx[i]]; + + if (ab == 0) + PG_RETURN_FLOAT8(1); + + PG_RETURN_FLOAT8(1 - (ab / ((double) (aa + bb - ab)))); +} diff --git a/test/expected/functions.out b/test/expected/functions.out index 6f32197..5c20b43 100644 --- a/test/expected/functions.out +++ b/test/expected/functions.out @@ -234,6 +234,38 @@ SELECT hamming_distance(B'111', B'000'); SELECT hamming_distance(B'111', B'00'); ERROR: different bit lengths 3 and 2 +SELECT jaccard_distance(B'1111', B'1111'); + jaccard_distance +------------------ + 0 +(1 row) + +SELECT jaccard_distance(B'1111', B'1110'); + jaccard_distance +------------------ + 0.25 +(1 row) + +SELECT jaccard_distance(B'1111', B'1100'); + jaccard_distance +------------------ + 0.5 +(1 row) + +SELECT jaccard_distance(B'1111', B'1000'); + jaccard_distance +------------------ + 0.75 +(1 row) + +SELECT jaccard_distance(B'1111', B'0000'); + jaccard_distance +------------------ + 1 +(1 row) + +SELECT jaccard_distance(B'1111', B'000'); +ERROR: different bit lengths 4 and 3 SELECT quantize_binary('[1,0,-1]'); quantize_binary ----------------- diff --git a/test/sql/functions.sql b/test/sql/functions.sql index b347390..0cac182 100644 --- a/test/sql/functions.sql +++ b/test/sql/functions.sql @@ -54,6 +54,13 @@ SELECT hamming_distance(B'111', B'100'); SELECT hamming_distance(B'111', B'000'); SELECT hamming_distance(B'111', B'00'); +SELECT jaccard_distance(B'1111', B'1111'); +SELECT jaccard_distance(B'1111', B'1110'); +SELECT jaccard_distance(B'1111', B'1100'); +SELECT jaccard_distance(B'1111', B'1000'); +SELECT jaccard_distance(B'1111', B'0000'); +SELECT jaccard_distance(B'1111', B'000'); + SELECT quantize_binary('[1,0,-1]'); SELECT quantize_binary('[0,0.1,-0.2,-0.3,0.4,0.5,0.6,-0.7,0.8,-0.9,1]');