From af9d50481d80399b4f80a2392c699ab00deec098 Mon Sep 17 00:00:00 2001 From: Andrew Kane Date: Mon, 22 Apr 2024 12:44:03 -0700 Subject: [PATCH] Added support for indexing L1 distance --- CHANGELOG.md | 1 + README.md | 13 +++++++++++++ sql/vector--0.6.2--0.7.0.sql | 16 ++++++++++++++++ sql/vector.sql | 16 ++++++++++++++++ test/expected/hnsw_l1.out | 21 +++++++++++++++++++++ test/expected/ivfflat_l1.out | 21 +++++++++++++++++++++ test/sql/hnsw_l1.sql | 12 ++++++++++++ test/sql/ivfflat_l1.sql | 12 ++++++++++++ test/t/003_ivfflat_build_recall.pl | 4 ++-- test/t/005_ivfflat_query_recall.pl | 4 ++-- test/t/012_hnsw_build_recall.pl | 4 ++-- test/t/013_hnsw_insert_recall.pl | 4 ++-- test/t/017_ivfflat_insert_recall.pl | 4 ++-- 13 files changed, 122 insertions(+), 10 deletions(-) create mode 100644 test/expected/hnsw_l1.out create mode 100644 test/expected/ivfflat_l1.out create mode 100644 test/sql/hnsw_l1.sql create mode 100644 test/sql/ivfflat_l1.sql diff --git a/CHANGELOG.md b/CHANGELOG.md index 4c4e8b3..ea6e1e7 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -3,6 +3,7 @@ - Added `halfvec` type - Added `sparsevec` type - Added support for indexing `bit` type +- Added support for indexing L1 distance - Added `binary_quantize` function - Added `hamming_distance` function - Added `jaccard_distance` function diff --git a/README.md b/README.md index 7bedb9b..458c3eb 100644 --- a/README.md +++ b/README.md @@ -227,6 +227,12 @@ Cosine distance CREATE INDEX ON items USING hnsw (embedding vector_cosine_ops); ``` +L1 distance - unreleased + +```sql +CREATE INDEX ON items USING hnsw (embedding vector_l1_ops); +``` + Hamming distance - unreleased ```sql @@ -349,6 +355,12 @@ Cosine distance CREATE INDEX ON items USING ivfflat (embedding vector_cosine_ops) WITH (lists = 100); ``` +L1 distance - unreleased + +```sql +CREATE INDEX ON items USING ivfflat (embedding vector_l1_ops) WITH (lists = 100); +``` + Supported types are: - `vector` - up to 2,000 dimensions @@ -855,6 +867,7 @@ Operator | Description | Added <-> | Euclidean distance | <#> | negative inner product | <=> | cosine distance | +<+> | taxicab distance | unreleased ### Vector Functions diff --git a/sql/vector--0.6.2--0.7.0.sql b/sql/vector--0.6.2--0.7.0.sql index f9483fc..9227487 100644 --- a/sql/vector--0.6.2--0.7.0.sql +++ b/sql/vector--0.6.2--0.7.0.sql @@ -13,10 +13,26 @@ CREATE FUNCTION subvector(vector, int, int) RETURNS vector CREATE FUNCTION vector_concat(vector, vector) RETURNS vector AS 'MODULE_PATHNAME' LANGUAGE C IMMUTABLE STRICT PARALLEL SAFE; +CREATE OPERATOR <+> ( + LEFTARG = vector, RIGHTARG = vector, PROCEDURE = l1_distance, + COMMUTATOR = '<+>' +); + CREATE OPERATOR || ( LEFTARG = vector, RIGHTARG = vector, PROCEDURE = vector_concat ); +CREATE OPERATOR CLASS vector_l1_ops + FOR TYPE vector USING ivfflat AS + OPERATOR 1 <+> (vector, vector) FOR ORDER BY float_ops, + FUNCTION 1 l1_distance(vector, vector), + FUNCTION 3 l1_distance(vector, vector); + +CREATE OPERATOR CLASS vector_l1_ops + FOR TYPE vector USING hnsw AS + OPERATOR 1 <+> (vector, vector) FOR ORDER BY float_ops, + FUNCTION 1 l1_distance(vector, vector); + CREATE FUNCTION hamming_distance(bit, bit) RETURNS float8 AS 'MODULE_PATHNAME' LANGUAGE C IMMUTABLE STRICT PARALLEL SAFE; diff --git a/sql/vector.sql b/sql/vector.sql index 6dcdbd7..a039620 100644 --- a/sql/vector.sql +++ b/sql/vector.sql @@ -186,6 +186,11 @@ CREATE OPERATOR <=> ( COMMUTATOR = '<=>' ); +CREATE OPERATOR <+> ( + LEFTARG = vector, RIGHTARG = vector, PROCEDURE = l1_distance, + COMMUTATOR = '<+>' +); + CREATE OPERATOR + ( LEFTARG = vector, RIGHTARG = vector, PROCEDURE = vector_add, COMMUTATOR = + @@ -288,6 +293,12 @@ CREATE OPERATOR CLASS vector_cosine_ops FUNCTION 3 vector_spherical_distance(vector, vector), FUNCTION 4 vector_norm(vector); +CREATE OPERATOR CLASS vector_l1_ops + FOR TYPE vector USING ivfflat AS + OPERATOR 1 <+> (vector, vector) FOR ORDER BY float_ops, + FUNCTION 1 l1_distance(vector, vector), + FUNCTION 3 l1_distance(vector, vector); + CREATE OPERATOR CLASS vector_l2_ops FOR TYPE vector USING hnsw AS OPERATOR 1 <-> (vector, vector) FOR ORDER BY float_ops, @@ -304,6 +315,11 @@ CREATE OPERATOR CLASS vector_cosine_ops FUNCTION 1 vector_negative_inner_product(vector, vector), FUNCTION 2 vector_norm(vector); +CREATE OPERATOR CLASS vector_l1_ops + FOR TYPE vector USING hnsw AS + OPERATOR 1 <+> (vector, vector) FOR ORDER BY float_ops, + FUNCTION 1 l1_distance(vector, vector); + -- bit functions CREATE FUNCTION hamming_distance(bit, bit) RETURNS float8 diff --git a/test/expected/hnsw_l1.out b/test/expected/hnsw_l1.out new file mode 100644 index 0000000..28473f6 --- /dev/null +++ b/test/expected/hnsw_l1.out @@ -0,0 +1,21 @@ +SET enable_seqscan = off; +CREATE TABLE t (val vector(3)); +INSERT INTO t (val) VALUES ('[0,0,0]'), ('[1,2,3]'), ('[1,1,1]'), (NULL); +CREATE INDEX ON t USING hnsw (val vector_l1_ops); +INSERT INTO t (val) VALUES ('[1,2,4]'); +SELECT * FROM t ORDER BY val <+> '[3,3,3]'; + val +--------- + [1,2,3] + [1,2,4] + [1,1,1] + [0,0,0] +(4 rows) + +SELECT COUNT(*) FROM (SELECT * FROM t ORDER BY val <+> (SELECT NULL::vector)) t2; + count +------- + 4 +(1 row) + +DROP TABLE t; diff --git a/test/expected/ivfflat_l1.out b/test/expected/ivfflat_l1.out new file mode 100644 index 0000000..abd0050 --- /dev/null +++ b/test/expected/ivfflat_l1.out @@ -0,0 +1,21 @@ +SET enable_seqscan = off; +CREATE TABLE t (val vector(3)); +INSERT INTO t (val) VALUES ('[0,0,0]'), ('[1,2,3]'), ('[1,1,1]'), (NULL); +CREATE INDEX ON t USING ivfflat (val vector_l1_ops) WITH (lists = 1); +INSERT INTO t (val) VALUES ('[1,2,4]'); +SELECT * FROM t ORDER BY val <+> '[3,3,3]'; + val +--------- + [1,2,3] + [1,2,4] + [1,1,1] + [0,0,0] +(4 rows) + +SELECT COUNT(*) FROM (SELECT * FROM t ORDER BY val <+> (SELECT NULL::vector)) t2; + count +------- + 4 +(1 row) + +DROP TABLE t; diff --git a/test/sql/hnsw_l1.sql b/test/sql/hnsw_l1.sql new file mode 100644 index 0000000..0d52d0a --- /dev/null +++ b/test/sql/hnsw_l1.sql @@ -0,0 +1,12 @@ +SET enable_seqscan = off; + +CREATE TABLE t (val vector(3)); +INSERT INTO t (val) VALUES ('[0,0,0]'), ('[1,2,3]'), ('[1,1,1]'), (NULL); +CREATE INDEX ON t USING hnsw (val vector_l1_ops); + +INSERT INTO t (val) VALUES ('[1,2,4]'); + +SELECT * FROM t ORDER BY val <+> '[3,3,3]'; +SELECT COUNT(*) FROM (SELECT * FROM t ORDER BY val <+> (SELECT NULL::vector)) t2; + +DROP TABLE t; diff --git a/test/sql/ivfflat_l1.sql b/test/sql/ivfflat_l1.sql new file mode 100644 index 0000000..d09c3d6 --- /dev/null +++ b/test/sql/ivfflat_l1.sql @@ -0,0 +1,12 @@ +SET enable_seqscan = off; + +CREATE TABLE t (val vector(3)); +INSERT INTO t (val) VALUES ('[0,0,0]'), ('[1,2,3]'), ('[1,1,1]'), (NULL); +CREATE INDEX ON t USING ivfflat (val vector_l1_ops) WITH (lists = 1); + +INSERT INTO t (val) VALUES ('[1,2,4]'); + +SELECT * FROM t ORDER BY val <+> '[3,3,3]'; +SELECT COUNT(*) FROM (SELECT * FROM t ORDER BY val <+> (SELECT NULL::vector)) t2; + +DROP TABLE t; diff --git a/test/t/003_ivfflat_build_recall.pl b/test/t/003_ivfflat_build_recall.pl index 21e0c8d..e691deb 100644 --- a/test/t/003_ivfflat_build_recall.pl +++ b/test/t/003_ivfflat_build_recall.pl @@ -70,8 +70,8 @@ for (1 .. 20) } # Check each index type -my @operators = ("<->", "<#>", "<=>"); -my @opclasses = ("vector_l2_ops", "vector_ip_ops", "vector_cosine_ops"); +my @operators = ("<->", "<#>", "<=>", "<+>"); +my @opclasses = ("vector_l2_ops", "vector_ip_ops", "vector_cosine_ops", "vector_l1_ops"); for my $i (0 .. $#operators) { diff --git a/test/t/005_ivfflat_query_recall.pl b/test/t/005_ivfflat_query_recall.pl index 1edebb3..93fe762 100644 --- a/test/t/005_ivfflat_query_recall.pl +++ b/test/t/005_ivfflat_query_recall.pl @@ -17,8 +17,8 @@ $node->safe_psql("postgres", ); # Check each index type -my @operators = ("<->", "<#>", "<=>"); -my @opclasses = ("vector_l2_ops", "vector_ip_ops", "vector_cosine_ops"); +my @operators = ("<->", "<#>", "<=>", "<+>"); +my @opclasses = ("vector_l2_ops", "vector_ip_ops", "vector_cosine_ops", "vector_l1_ops"); for my $i (0 .. $#operators) { diff --git a/test/t/012_hnsw_build_recall.pl b/test/t/012_hnsw_build_recall.pl index 163a472..1298009 100644 --- a/test/t/012_hnsw_build_recall.pl +++ b/test/t/012_hnsw_build_recall.pl @@ -67,8 +67,8 @@ for (1 .. 20) } # Check each index type -my @operators = ("<->", "<#>", "<=>"); -my @opclasses = ("vector_l2_ops", "vector_ip_ops", "vector_cosine_ops"); +my @operators = ("<->", "<#>", "<=>", "<+>"); +my @opclasses = ("vector_l2_ops", "vector_ip_ops", "vector_cosine_ops", "vector_l1_ops"); for my $i (0 .. $#operators) { diff --git a/test/t/013_hnsw_insert_recall.pl b/test/t/013_hnsw_insert_recall.pl index 2c87108..006231b 100644 --- a/test/t/013_hnsw_insert_recall.pl +++ b/test/t/013_hnsw_insert_recall.pl @@ -64,8 +64,8 @@ for (1 .. 20) } # Check each index type -my @operators = ("<->", "<#>", "<=>"); -my @opclasses = ("vector_l2_ops", "vector_ip_ops", "vector_cosine_ops"); +my @operators = ("<->", "<#>", "<=>", "<+>"); +my @opclasses = ("vector_l2_ops", "vector_ip_ops", "vector_cosine_ops", "vector_l1_ops"); for my $i (0 .. $#operators) { diff --git a/test/t/017_ivfflat_insert_recall.pl b/test/t/017_ivfflat_insert_recall.pl index c2e320c..148ccc9 100644 --- a/test/t/017_ivfflat_insert_recall.pl +++ b/test/t/017_ivfflat_insert_recall.pl @@ -66,8 +66,8 @@ for (1 .. 20) } # Check each index type -my @operators = ("<->", "<#>", "<=>"); -my @opclasses = ("vector_l2_ops", "vector_ip_ops", "vector_cosine_ops"); +my @operators = ("<->", "<#>", "<=>", "<+>"); +my @opclasses = ("vector_l2_ops", "vector_ip_ops", "vector_cosine_ops", "vector_l1_ops"); for my $i (0 .. $#operators) {