From 8fcf77f89acca533779299f72ff4877854aeee77 Mon Sep 17 00:00:00 2001 From: Andrew Kane Date: Wed, 15 Nov 2023 15:37:27 -0800 Subject: [PATCH] Added support for bigint attributes [skip ci] --- sql/vector--0.5.1--0.6.0.sql | 8 ++++++++ sql/vector.sql | 8 ++++++++ src/hnsw.c | 14 ++++++++++++++ test/t/019_hnsw_filtering.pl | 2 +- 4 files changed, 31 insertions(+), 1 deletion(-) diff --git a/sql/vector--0.5.1--0.6.0.sql b/sql/vector--0.5.1--0.6.0.sql index 3030c31..8cee53d 100644 --- a/sql/vector--0.5.1--0.6.0.sql +++ b/sql/vector--0.5.1--0.6.0.sql @@ -4,7 +4,15 @@ CREATE FUNCTION hnsw_attribute_distance(integer, integer) RETURNS float8 AS 'MODULE_PATHNAME', 'hnsw_int4_attribute_distance' LANGUAGE C IMMUTABLE STRICT PARALLEL SAFE; +CREATE FUNCTION hnsw_attribute_distance(bigint, bigint) RETURNS float8 + AS 'MODULE_PATHNAME', 'hnsw_int8_attribute_distance' LANGUAGE C IMMUTABLE STRICT PARALLEL SAFE; + CREATE OPERATOR CLASS vector_integer_ops DEFAULT FOR TYPE integer USING hnsw AS OPERATOR 2 = (integer, integer), FUNCTION 3 hnsw_attribute_distance(integer, integer); + +CREATE OPERATOR CLASS vector_bigint_ops + DEFAULT FOR TYPE bigint USING hnsw AS + OPERATOR 2 = (bigint, bigint), + FUNCTION 3 hnsw_attribute_distance(bigint, bigint); diff --git a/sql/vector.sql b/sql/vector.sql index 9149369..53e7c80 100644 --- a/sql/vector.sql +++ b/sql/vector.sql @@ -296,7 +296,15 @@ CREATE OPERATOR CLASS vector_cosine_ops CREATE FUNCTION hnsw_attribute_distance(integer, integer) RETURNS float8 AS 'MODULE_PATHNAME', 'hnsw_int4_attribute_distance' LANGUAGE C IMMUTABLE STRICT PARALLEL SAFE; +CREATE FUNCTION hnsw_attribute_distance(bigint, bigint) RETURNS float8 + AS 'MODULE_PATHNAME', 'hnsw_int8_attribute_distance' LANGUAGE C IMMUTABLE STRICT PARALLEL SAFE; + CREATE OPERATOR CLASS vector_integer_ops DEFAULT FOR TYPE integer USING hnsw AS OPERATOR 2 = (integer, integer), FUNCTION 3 hnsw_attribute_distance(integer, integer); + +CREATE OPERATOR CLASS vector_bigint_ops + DEFAULT FOR TYPE bigint USING hnsw AS + OPERATOR 2 = (bigint, bigint), + FUNCTION 3 hnsw_attribute_distance(bigint, bigint); diff --git a/src/hnsw.c b/src/hnsw.c index 88c8e7d..96f0370 100644 --- a/src/hnsw.c +++ b/src/hnsw.c @@ -236,3 +236,17 @@ hnsw_int4_attribute_distance(PG_FUNCTION_ARGS) PG_RETURN_FLOAT8(distance); } + +/* + * Get the distance between two int8 attributes + */ +PGDLLEXPORT PG_FUNCTION_INFO_V1(hnsw_int8_attribute_distance); +Datum +hnsw_int8_attribute_distance(PG_FUNCTION_ARGS) +{ + int64 a = PG_GETARG_INT64(0); + int64 b = PG_GETARG_INT64(1); + double distance = ((double) a) - ((double) b); + + PG_RETURN_FLOAT8(distance); +} diff --git a/test/t/019_hnsw_filtering.pl b/test/t/019_hnsw_filtering.pl index fb7b3a6..563c5a0 100644 --- a/test/t/019_hnsw_filtering.pl +++ b/test/t/019_hnsw_filtering.pl @@ -56,7 +56,7 @@ $node->start; # Create table $node->safe_psql("postgres", "CREATE EXTENSION vector;"); -$node->safe_psql("postgres", "CREATE TABLE tst (i int4, v vector($dim), c int4);"); +$node->safe_psql("postgres", "CREATE TABLE tst (i int4, v vector($dim), c int8);"); $node->safe_psql("postgres", "INSERT INTO tst SELECT i, ARRAY[$array_sql], i % $nc FROM generate_series(1, 20000) i;" );