From 94a444f02958f04946fc2dfccf0e10fe3718b660 Mon Sep 17 00:00:00 2001
From: Andrew Kane
Date: Mon, 1 Apr 2024 20:30:55 -0700
Subject: [PATCH] Added support for bit vectors to HNSW

---
 CHANGELOG.md                        |   7 ++
 Makefile                            |   2 +-
 Makefile.win                        |   2 +-
 README.md                           |  36 +++++++-
 sql/vector--0.6.2--0.7.0.sql        |  31 +++++++
 sql/vector.sql                      |  31 +++++++
 src/bitvector.c                     |  90 ++++++++++++++++++
 src/bitvector.h                     |   8 ++
 src/hnsw.h                          |   3 +-
 src/hnswbuild.c                     |  19 +++-
 src/hnswscan.c                      |   2 +
 src/vector.c                        |  20 ++++
 test/expected/bit_functions.out     |  64 +++++++++++++
 test/expected/functions.out         |  12 +++
 test/expected/hnsw_hamming.out      |  21 +++++
 test/expected/hnsw_jaccard.out      |  21 +++++
 test/sql/bit_functions.sql          |  13 +++
 test/sql/functions.sql              |   3 +
 test/sql/hnsw_hamming.sql           |  12 +++
 test/sql/hnsw_jaccard.sql           |  12 +++
 test/t/020_hnsw_bit_build_recall.pl | 137 ++++++++++++++++++++++++++++
 21 files changed, 541 insertions(+), 5 deletions(-)
 create mode 100644 sql/vector--0.6.2--0.7.0.sql
 create mode 100644 src/bitvector.c
 create mode 100644 src/bitvector.h
 create mode 100644 test/expected/bit_functions.out
 create mode 100644 test/expected/hnsw_hamming.out
 create mode 100644 test/expected/hnsw_jaccard.out
 create mode 100644 test/sql/bit_functions.sql
 create mode 100644 test/sql/hnsw_hamming.sql
 create mode 100644 test/sql/hnsw_jaccard.sql
 create mode 100644 test/t/020_hnsw_bit_build_recall.pl

diff --git a/CHANGELOG.md b/CHANGELOG.md
index 9bcea19..d67307c 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -1,3 +1,10 @@
+## 0.7.0 (unreleased)
+
+- Added support for bit vectors to HNSW
+- Added `hamming_distance` function
+- Added `jaccard_distance` function
+- Added `quantize_binary` function
+
 ## 0.6.2 (2024-03-18)
 
 - Reduced lock contention with parallel HNSW index builds
diff --git a/Makefile b/Makefile
index dff5232..04758d0 100644
--- a/Makefile
+++ b/Makefile
@@ -3,7 +3,7 @@ EXTVERSION = 0.6.2
 
 MODULE_big = vector
 DATA = $(wildcard sql/*--*.sql)
-OBJS = src/hnsw.o src/hnswbuild.o src/hnswinsert.o src/hnswscan.o src/hnswutils.o src/hnswvacuum.o src/ivfbuild.o src/ivfflat.o src/ivfinsert.o src/ivfkmeans.o src/ivfscan.o src/ivfutils.o src/ivfvacuum.o src/vector.o
+OBJS = src/bitvector.o src/hnsw.o src/hnswbuild.o src/hnswinsert.o src/hnswscan.o src/hnswutils.o src/hnswvacuum.o src/ivfbuild.o src/ivfflat.o src/ivfinsert.o src/ivfkmeans.o src/ivfscan.o src/ivfutils.o src/ivfvacuum.o src/vector.o
 HEADERS = src/vector.h
 
 TESTS = $(wildcard test/sql/*.sql)
diff --git a/Makefile.win b/Makefile.win
index 1bb193e..b3928bc 100644
--- a/Makefile.win
+++ b/Makefile.win
@@ -1,7 +1,7 @@
 EXTENSION = vector
 EXTVERSION = 0.6.2
 
-OBJS = src\hnsw.obj src\hnswbuild.obj src\hnswinsert.obj src\hnswscan.obj src\hnswutils.obj src\hnswvacuum.obj src\ivfbuild.obj src\ivfflat.obj src\ivfinsert.obj src\ivfkmeans.obj src\ivfscan.obj src\ivfutils.obj src\ivfvacuum.obj src\vector.obj
+OBJS = src\bitvector.obj src\hnsw.obj src\hnswbuild.obj src\hnswinsert.obj src\hnswscan.obj src\hnswutils.obj src\hnswvacuum.obj src\ivfbuild.obj src\ivfflat.obj src\ivfinsert.obj src\ivfkmeans.obj src\ivfscan.obj src\ivfutils.obj src\ivfvacuum.obj src\vector.obj
 HEADERS = src\vector.h
 
 REGRESS = btree cast copy functions input ivfflat_cosine ivfflat_ip ivfflat_l2 ivfflat_options ivfflat_unlogged
diff --git a/README.md b/README.md
index f9849c8..2f02080 100644
--- a/README.md
+++ b/README.md
@@ -221,7 +221,19 @@ Cosine distance
 CREATE INDEX ON items USING hnsw (embedding vector_cosine_ops);
 ```
 
-Vectors with up to 2,000 dimensions can be indexed.
+Hamming distance - unreleased
+
+```sql
+CREATE INDEX ON items USING hnsw (embedding bit_hamming_ops);
+```
+
+Jaccard distance - unreleased
+
+```sql
+CREATE INDEX ON items USING hnsw (embedding bit_jaccard_ops);
+```
+
+Vectors with up to 2,000 dimensions can be indexed, or bit vectors with up to 64,000 dimensions.
 
 ### Index Options
 
@@ -699,6 +711,9 @@ Also, note that `NULL` vectors are not indexed (as well as zero vectors for cosi
 
 ## Reference
 
+- [Vector](#vector-type)
+- [Bit](#bit-type)
+
 ### Vector Type
 
 Each vector takes `4 * dimensions + 8` bytes of storage. Each element is a single-precision floating-point number (like the `real` type in Postgres), and all elements must be finite (no `NaN`, `Infinity` or `-Infinity`). Vectors can have up to 16,000 dimensions.
@@ -722,6 +737,7 @@ cosine_distance(vector, vector) → double precision | cosine distance |
 inner_product(vector, vector) → double precision | inner product |
 l2_distance(vector, vector) → double precision | Euclidean distance |
 l1_distance(vector, vector) → double precision | taxicab distance | 0.5.0
+quantize_binary(vector) → bit | quantize | unreleased
 vector_dims(vector) → integer | number of dimensions |
 vector_norm(vector) → double precision | Euclidean norm |
 
@@ -732,6 +748,24 @@ Function | Description | Added
 avg(vector) → vector | average |
 sum(vector) → vector | sum | 0.5.0
 
+### Bit Type
+
+Each bit vector takes `dimensions / 8 + (5 or 8)` bytes of storage. See the [Postgres docs](https://www.postgresql.org/docs/current/datatype-bit.html) for more info.
+
+### Bit Operators
+
+Operator | Description | Added
+--- | --- | ---
+<~> | Hamming distance | unreleased
+<%> | Jaccard distance | unreleased
+
+### Bit Functions
+
+Function | Description | Added
+--- | --- | ---
+hamming_distance(bit, bit) → double precision | Hamming distance | unreleased
+jaccard_distance(bit, bit) → double precision | Jaccard distance | unreleased
+
 ## Installation Notes - Linux and Mac
 
 ### Postgres Location
diff --git a/sql/vector--0.6.2--0.7.0.sql b/sql/vector--0.6.2--0.7.0.sql
new file mode 100644
index 0000000..68409d3
--- /dev/null
+++ b/sql/vector--0.6.2--0.7.0.sql
@@ -0,0 +1,31 @@
+-- complain if script is sourced in psql, rather than via CREATE EXTENSION
+\echo Use "ALTER EXTENSION vector UPDATE TO '0.7.0'" to load this file. \quit
+
+CREATE FUNCTION quantize_binary(vector) RETURNS bit
+	AS 'MODULE_PATHNAME' LANGUAGE C IMMUTABLE STRICT PARALLEL SAFE;
+
+CREATE FUNCTION hamming_distance(bit, bit) RETURNS float8
+	AS 'MODULE_PATHNAME' LANGUAGE C IMMUTABLE STRICT PARALLEL SAFE;
+
+CREATE FUNCTION jaccard_distance(bit, bit) RETURNS float8
+	AS 'MODULE_PATHNAME' LANGUAGE C IMMUTABLE STRICT PARALLEL SAFE;
+
+CREATE OPERATOR <~> (
+	LEFTARG = bit, RIGHTARG = bit, PROCEDURE = hamming_distance,
+	COMMUTATOR = '<~>'
+);
+
+CREATE OPERATOR <%> (
+	LEFTARG = bit, RIGHTARG = bit, PROCEDURE = jaccard_distance,
+	COMMUTATOR = '<%>'
+);
+
+CREATE OPERATOR CLASS bit_hamming_ops
+	FOR TYPE bit USING hnsw AS
+	OPERATOR 1 <~> (bit, bit) FOR ORDER BY float_ops,
+	FUNCTION 1 hamming_distance(bit, bit);
+
+CREATE OPERATOR CLASS bit_jaccard_ops
+	FOR TYPE bit USING hnsw AS
+	OPERATOR 1 <%> (bit, bit) FOR ORDER BY float_ops,
+	FUNCTION 1 jaccard_distance(bit, bit);
diff --git a/sql/vector.sql b/sql/vector.sql
index 6a0b2cd..48b91fd 100644
--- a/sql/vector.sql
+++ b/sql/vector.sql
@@ -58,6 +58,9 @@ CREATE FUNCTION vector_sub(vector, vector) RETURNS vector
 CREATE FUNCTION vector_mul(vector, vector) RETURNS vector
 	AS 'MODULE_PATHNAME' LANGUAGE C IMMUTABLE STRICT PARALLEL SAFE;
 
+CREATE FUNCTION quantize_binary(vector) RETURNS bit
+	AS 'MODULE_PATHNAME' LANGUAGE C IMMUTABLE STRICT PARALLEL SAFE;
+
 -- vector private functions
 
 CREATE FUNCTION vector_lt(vector, vector) RETURNS bool
@@ -287,3 +290,31 @@ CREATE OPERATOR CLASS vector_cosine_ops
 	OPERATOR 1 <=> (vector, vector) FOR ORDER BY float_ops,
 	FUNCTION 1 vector_negative_inner_product(vector, vector),
 	FUNCTION 2 vector_norm(vector);
+
+-- bit functions
+
+CREATE FUNCTION hamming_distance(bit, bit) RETURNS float8
+	AS 'MODULE_PATHNAME' LANGUAGE C IMMUTABLE STRICT PARALLEL SAFE;
+
+CREATE FUNCTION jaccard_distance(bit, bit) RETURNS float8
+	AS 'MODULE_PATHNAME' LANGUAGE C IMMUTABLE STRICT PARALLEL SAFE;
+
+CREATE OPERATOR <~> (
+	LEFTARG = bit, RIGHTARG = bit, PROCEDURE = hamming_distance,
+	COMMUTATOR = '<~>'
+);
+
+CREATE OPERATOR <%> (
+	LEFTARG = bit, RIGHTARG = bit, PROCEDURE = jaccard_distance,
+	COMMUTATOR = '<%>'
+);
+
+CREATE OPERATOR CLASS bit_hamming_ops
+	FOR TYPE bit USING hnsw AS
+	OPERATOR 1 <~> (bit, bit) FOR ORDER BY float_ops,
+	FUNCTION 1 hamming_distance(bit, bit);
+
+CREATE OPERATOR CLASS bit_jaccard_ops
+	FOR TYPE bit USING hnsw AS
+	OPERATOR 1 <%> (bit, bit) FOR ORDER BY float_ops,
+	FUNCTION 1 jaccard_distance(bit, bit);
diff --git a/src/bitvector.c b/src/bitvector.c
new file mode 100644
index 0000000..7c5cb0f
--- /dev/null
+++ b/src/bitvector.c
@@ -0,0 +1,90 @@
+#include "postgres.h"
+
+#include "bitvector.h"
+#include "port/pg_bitutils.h"
+#include "utils/varbit.h"
+
+#if PG_VERSION_NUM >= 160000
+#include "varatt.h"
+#endif
+
+/*
+ * Allocate a zero-filled bit string whose bit length is set to dim
+ */
+VarBit *
+InitBitVector(int dim)
+{
+	VarBit	   *result;
+	int			size;
+
+	size = VARBITTOTALLEN(dim);
+	result = (VarBit *) palloc0(size);
+	SET_VARSIZE(result, size);
+	VARBITLEN(result) = dim;
+
+	return result;
+}
+
+/*
+ * Raise an error unless both bit strings have the same bit length
+ */
+static inline void
+CheckDims(VarBit *a, VarBit *b)
+{
+	if (VARBITLEN(a) != VARBITLEN(b))
+		ereport(ERROR,
+				(errcode(ERRCODE_DATA_EXCEPTION),
+				 errmsg("different bit lengths %u and %u", VARBITLEN(a), VARBITLEN(b))));
+}
+
+/*
+ * Hamming distance: count of set bits in a XOR b, summed byte by byte
+ */
+PGDLLEXPORT PG_FUNCTION_INFO_V1(hamming_distance);
+Datum
+hamming_distance(PG_FUNCTION_ARGS)
+{
+	VarBit	   *a = PG_GETARG_VARBIT_P(0);
+	VarBit	   *b = PG_GETARG_VARBIT_P(1);
+	unsigned char *ax = VARBITS(a);
+	unsigned char *bx = VARBITS(b);
+	uint64		distance = 0;
+
+	CheckDims(a, b);
+
+	/* TODO Improve performance */
+	for (uint32 i = 0; i < VARBITBYTES(a); i++)
+		distance += pg_number_of_ones[ax[i] ^ bx[i]];
+
+	PG_RETURN_FLOAT8((double) distance);
+}
+
+/*
+ * Jaccard distance: 1 - |a AND b| / (|a| + |b| - |a AND b|), or 1 if |a AND b| is 0
+ */
+PGDLLEXPORT PG_FUNCTION_INFO_V1(jaccard_distance);
+Datum
+jaccard_distance(PG_FUNCTION_ARGS)
+{
+	VarBit	   *a = PG_GETARG_VARBIT_P(0);
+	VarBit	   *b = PG_GETARG_VARBIT_P(1);
+	unsigned char *ax = VARBITS(a);
+	unsigned char *bx = VARBITS(b);
+	uint64		ab = 0;
+	uint64		aa;
+	uint64		bb;
+
+	CheckDims(a, b);
+
+	/* TODO Improve performance */
+	for (uint32 i = 0; i < VARBITBYTES(a); i++)
+		ab += pg_number_of_ones[ax[i] & bx[i]];
+
+	if (ab == 0)
+		PG_RETURN_FLOAT8(1);
+
+	aa = pg_popcount((char *) ax, VARBITBYTES(a));
+	bb = pg_popcount((char *) bx, VARBITBYTES(b));
+
+	PG_RETURN_FLOAT8(1 - (ab / ((double) (aa + bb - ab))));
+}
diff --git a/src/bitvector.h b/src/bitvector.h
new file mode 100644
index 0000000..b7dec9f
--- /dev/null
+++ b/src/bitvector.h
@@ -0,0 +1,8 @@
+#ifndef BITVECTOR_H
+#define BITVECTOR_H
+
+#include "utils/varbit.h"
+
+VarBit	   *InitBitVector(int dim);
+
+#endif
diff --git a/src/hnsw.h b/src/hnsw.h
index 901f22b..2e9adbf 100644
--- a/src/hnsw.h
+++ b/src/hnsw.h
@@ -57,7 +57,8 @@
 
 typedef enum HnswType
 {
-	HNSW_TYPE_VECTOR
+	HNSW_TYPE_VECTOR,
+	HNSW_TYPE_BIT
 }			HnswType;
 
 /* Build phases */
diff --git a/src/hnswbuild.c b/src/hnswbuild.c
index cd7150d..971d4ba 100644
--- a/src/hnswbuild.c
+++ b/src/hnswbuild.c
@@ -44,6 +44,7 @@
 #include "access/xact.h"
 #include "access/xloginsert.h"
 #include "catalog/index.h"
+#include "catalog/pg_type_d.h"
 #include "commands/progress.h"
 #include "hnsw.h"
 #include "miscadmin.h"
@@ -665,13 +666,27 @@ HnswSharedMemoryAlloc(Size size, void *state)
 	return chunk;
 }
 
+/*
+ * Maximum indexable dimensions: HNSW_MAX_DIM, or 32 times that for bit
+ */
+static int
+GetMaxDimensions(HnswType type)
+{
+	int			maxDimensions = HNSW_MAX_DIM;
+
+	if (type == HNSW_TYPE_BIT)
+		maxDimensions *= 32;
+
+	return maxDimensions;
+}
+
 /*
  * Initialize the build state
  */
 static void
 InitBuildState(HnswBuildState * buildstate, Relation heap, Relation index, IndexInfo *indexInfo, ForkNumber forkNum)
 {
-	int			maxDimensions = HNSW_MAX_DIM;
+	int			maxDimensions;
 
 	buildstate->heap = heap;
 	buildstate->index = index;
@@ -683,6 +698,8 @@ InitBuildState(HnswBuildState * buildstate, Relation heap, Relation index, Index
 	buildstate->efConstruction = HnswGetEfConstruction(index);
 	buildstate->dimensions = TupleDescAttr(index->rd_att, 0)->atttypmod;
 
+	maxDimensions = GetMaxDimensions(buildstate->type);
+
 	/* Require column to have dimensions to be indexed */
 	if (buildstate->dimensions < 0)
 		elog(ERROR, "column does not have dimensions");
diff --git a/src/hnswscan.c b/src/hnswscan.c
index 8dd4efd..d2016db 100644
--- a/src/hnswscan.c
+++ b/src/hnswscan.c
@@ -1,6 +1,8 @@
 #include "postgres.h"
 
 #include "access/relscan.h"
+#include "bitvector.h"
+#include "catalog/pg_type_d.h"
 #include "hnsw.h"
 #include "pgstat.h"
 #include "storage/bufmgr.h"
diff --git a/src/vector.c b/src/vector.c
index 2c23701..071aba1 100644
--- a/src/vector.c
+++ b/src/vector.c
@@ -2,6 +2,7 @@
 
 #include <math.h>
 
+#include "bitvector.h"
 #include "catalog/pg_type.h"
 #include "common/shortest_dec.h"
 #include "fmgr.h"
@@ -858,6 +859,25 @@ vector_mul(PG_FUNCTION_ARGS)
 	PG_RETURN_POINTER(result);
 }
 
+/*
+ * Binary quantization: one bit per element, set when the element is > 0
+ */
+PGDLLEXPORT PG_FUNCTION_INFO_V1(quantize_binary);
+Datum
+quantize_binary(PG_FUNCTION_ARGS)
+{
+	Vector	   *a = PG_GETARG_VECTOR_P(0);
+	float	   *ax = a->x;
+	VarBit	   *result = InitBitVector(a->dim);
+	unsigned char *rx = VARBITS(result);
+
+	for (int i = 0; i < a->dim; i++)
+		rx[i / 8] |= (ax[i] > 0) << (7 - (i % 8));
+
+	PG_RETURN_VARBIT_P(result);
+}
+
+
 /*
  * Internal helper to compare vectors
  */
diff --git a/test/expected/bit_functions.out b/test/expected/bit_functions.out
new file mode 100644
index 0000000..3647fa2
--- /dev/null
+++ b/test/expected/bit_functions.out
@@ -0,0 +1,64 @@
+SELECT hamming_distance(B'111', B'111');
+ hamming_distance 
+------------------
+                0
+(1 row)
+
+SELECT hamming_distance(B'111', B'110');
+ hamming_distance 
+------------------
+                1
+(1 row)
+
+SELECT hamming_distance(B'111', B'100');
+ hamming_distance 
+------------------
+                2
+(1 row)
+
+SELECT hamming_distance(B'111', B'000');
+ hamming_distance 
+------------------
+                3
+(1 row)
+
+SELECT hamming_distance(B'111', B'00');
+ERROR:  different bit lengths 3 and 2
+SELECT jaccard_distance(B'1111', B'1111');
+ jaccard_distance 
+------------------
+                0
+(1 row)
+
+SELECT jaccard_distance(B'1111', B'1110');
+ jaccard_distance 
+------------------
+             0.25
+(1 row)
+
+SELECT jaccard_distance(B'1111', B'1100');
+ jaccard_distance 
+------------------
+              0.5
+(1 row)
+
+SELECT jaccard_distance(B'1111', B'1000');
+ jaccard_distance 
+------------------
+             0.75
+(1 row)
+
+SELECT jaccard_distance(B'1111', B'0000');
+ jaccard_distance 
+------------------
+                1
+(1 row)
+
+SELECT jaccard_distance(B'1100', B'1000');
+ jaccard_distance 
+------------------
+              0.5
+(1 row)
+
+SELECT jaccard_distance(B'1111', B'000');
+ERROR:  different bit lengths 4 and 3
diff --git a/test/expected/functions.out b/test/expected/functions.out
index 12f8f6d..01ce11d 100644
--- a/test/expected/functions.out
+++ b/test/expected/functions.out
@@ -208,6 +208,18 @@ SELECT l1_distance('[3e38]'::vector, '[-3e38]');
     Infinity
 (1 row)
 
+SELECT quantize_binary('[1,0,-1]');
+ quantize_binary 
+-----------------
+ 100
+(1 row)
+
+SELECT quantize_binary('[0,0.1,-0.2,-0.3,0.4,0.5,0.6,-0.7,0.8,-0.9,1]');
+ quantize_binary 
+-----------------
+ 01001110101
+(1 row)
+
SELECT avg(v) FROM unnest(ARRAY['[1,2,3]'::vector, '[3,5,7]']) v; avg ----------- diff --git a/test/expected/hnsw_hamming.out b/test/expected/hnsw_hamming.out new file mode 100644 index 0000000..0483281 --- /dev/null +++ b/test/expected/hnsw_hamming.out @@ -0,0 +1,21 @@ +SET enable_seqscan = off; +CREATE TABLE t (val bit(3)); +INSERT INTO t (val) VALUES (B'000'), (B'100'), (B'111'), (NULL); +CREATE INDEX ON t USING hnsw (val bit_hamming_ops); +INSERT INTO t (val) VALUES (B'110'); +SELECT * FROM t ORDER BY val <~> B'111'; + val +----- + 111 + 110 + 100 + 000 +(4 rows) + +SELECT COUNT(*) FROM (SELECT * FROM t ORDER BY val <~> (SELECT NULL::bit)) t2; + count +------- + 4 +(1 row) + +DROP TABLE t; diff --git a/test/expected/hnsw_jaccard.out b/test/expected/hnsw_jaccard.out new file mode 100644 index 0000000..6524f00 --- /dev/null +++ b/test/expected/hnsw_jaccard.out @@ -0,0 +1,21 @@ +SET enable_seqscan = off; +CREATE TABLE t (val bit(4)); +INSERT INTO t (val) VALUES (B'0000'), (B'1100'), (B'1111'), (NULL); +CREATE INDEX ON t USING hnsw (val bit_jaccard_ops); +INSERT INTO t (val) VALUES (B'1110'); +SELECT * FROM t ORDER BY val <%> B'1111'; + val +------ + 1111 + 1110 + 1100 + 0000 +(4 rows) + +SELECT COUNT(*) FROM (SELECT * FROM t ORDER BY val <%> (SELECT NULL::bit)) t2; + count +------- + 4 +(1 row) + +DROP TABLE t; diff --git a/test/sql/bit_functions.sql b/test/sql/bit_functions.sql new file mode 100644 index 0000000..2248338 --- /dev/null +++ b/test/sql/bit_functions.sql @@ -0,0 +1,13 @@ +SELECT hamming_distance(B'111', B'111'); +SELECT hamming_distance(B'111', B'110'); +SELECT hamming_distance(B'111', B'100'); +SELECT hamming_distance(B'111', B'000'); +SELECT hamming_distance(B'111', B'00'); + +SELECT jaccard_distance(B'1111', B'1111'); +SELECT jaccard_distance(B'1111', B'1110'); +SELECT jaccard_distance(B'1111', B'1100'); +SELECT jaccard_distance(B'1111', B'1000'); +SELECT jaccard_distance(B'1111', B'0000'); +SELECT jaccard_distance(B'1100', B'1000'); +SELECT 
jaccard_distance(B'1111', B'000'); diff --git a/test/sql/functions.sql b/test/sql/functions.sql index 7e820d7..c604be6 100644 --- a/test/sql/functions.sql +++ b/test/sql/functions.sql @@ -48,6 +48,9 @@ SELECT l1_distance('[0,0]'::vector, '[0,1]'); SELECT l1_distance('[1,2]'::vector, '[3]'); SELECT l1_distance('[3e38]'::vector, '[-3e38]'); +SELECT quantize_binary('[1,0,-1]'); +SELECT quantize_binary('[0,0.1,-0.2,-0.3,0.4,0.5,0.6,-0.7,0.8,-0.9,1]'); + SELECT avg(v) FROM unnest(ARRAY['[1,2,3]'::vector, '[3,5,7]']) v; SELECT avg(v) FROM unnest(ARRAY['[1,2,3]'::vector, '[3,5,7]', NULL]) v; SELECT avg(v) FROM unnest(ARRAY[]::vector[]) v; diff --git a/test/sql/hnsw_hamming.sql b/test/sql/hnsw_hamming.sql new file mode 100644 index 0000000..fb21511 --- /dev/null +++ b/test/sql/hnsw_hamming.sql @@ -0,0 +1,12 @@ +SET enable_seqscan = off; + +CREATE TABLE t (val bit(3)); +INSERT INTO t (val) VALUES (B'000'), (B'100'), (B'111'), (NULL); +CREATE INDEX ON t USING hnsw (val bit_hamming_ops); + +INSERT INTO t (val) VALUES (B'110'); + +SELECT * FROM t ORDER BY val <~> B'111'; +SELECT COUNT(*) FROM (SELECT * FROM t ORDER BY val <~> (SELECT NULL::bit)) t2; + +DROP TABLE t; diff --git a/test/sql/hnsw_jaccard.sql b/test/sql/hnsw_jaccard.sql new file mode 100644 index 0000000..ca61c53 --- /dev/null +++ b/test/sql/hnsw_jaccard.sql @@ -0,0 +1,12 @@ +SET enable_seqscan = off; + +CREATE TABLE t (val bit(4)); +INSERT INTO t (val) VALUES (B'0000'), (B'1100'), (B'1111'), (NULL); +CREATE INDEX ON t USING hnsw (val bit_jaccard_ops); + +INSERT INTO t (val) VALUES (B'1110'); + +SELECT * FROM t ORDER BY val <%> B'1111'; +SELECT COUNT(*) FROM (SELECT * FROM t ORDER BY val <%> (SELECT NULL::bit)) t2; + +DROP TABLE t; diff --git a/test/t/020_hnsw_bit_build_recall.pl b/test/t/020_hnsw_bit_build_recall.pl new file mode 100644 index 0000000..09a0258 --- /dev/null +++ b/test/t/020_hnsw_bit_build_recall.pl @@ -0,0 +1,137 @@ +use strict; +use warnings; +use PostgresNode; +use TestLib; +use Test::More; + +my 
$node; +my @queries = (); +my @expected; +my $limit = 20; +my $dim = 52; +my $max = 2**$dim; + +sub test_recall +{ + my ($min, $operator) = @_; + my $correct = 0; + my $total = 0; + + my $explain = $node->safe_psql("postgres", qq( + SET enable_seqscan = off; + SET hnsw.ef_search = 100; + EXPLAIN ANALYZE SELECT i FROM tst ORDER BY v $operator $queries[0] LIMIT $limit; + )); + like($explain, qr/Index Scan/); + + for my $i (0 .. $#queries) + { + my $actual = $node->safe_psql("postgres", qq( + SET enable_seqscan = off; + SET hnsw.ef_search = 100; + SELECT i FROM tst ORDER BY v $operator $queries[$i] LIMIT $limit; + )); + my @actual_ids = split("\n", $actual); + + my @expected_ids = split("\n", $expected[$i]); + my %expected_set = map { $_ => 1 } @expected_ids; + + foreach (@actual_ids) + { + if (exists($expected_set{$_})) + { + $correct++; + } + } + + $total += $limit; + } + + cmp_ok($correct / $total, ">=", $min, $operator); +} + +# Initialize node +$node = get_new_node('node'); +$node->init; +$node->start; + +# Create table +$node->safe_psql("postgres", "CREATE EXTENSION vector;"); +$node->safe_psql("postgres", "CREATE TABLE tst (i int4, v bit($dim));"); +$node->safe_psql("postgres", + "INSERT INTO tst SELECT i, (random() * $max)::bigint::bit($dim) FROM generate_series(1, 10000) i;" +); + +# Generate queries +for (1 .. 20) +{ + my $r = int(rand() * $max); + push(@queries, "${r}::bigint::bit($dim)"); +} + +# Check each index type +my @operators = ("<~>", "<\%>"); +my @opclasses = ("bit_hamming_ops", "bit_jaccard_ops"); + +for my $i (0 .. 
$#operators) +{ + my $operator = $operators[$i]; + my $opclass = $opclasses[$i]; + + # Get exact results + @expected = (); + foreach (@queries) + { + # Handle ties + my $res = $node->safe_psql("postgres", qq( + WITH top AS ( + SELECT v $operator $_ AS distance FROM tst ORDER BY v $operator $_ LIMIT $limit + ) + SELECT i FROM tst WHERE (v $operator $_) <= (SELECT MAX(distance) FROM top) + )); + push(@expected, $res); + } + + # Build index serially + $node->safe_psql("postgres", qq( + SET max_parallel_maintenance_workers = 0; + CREATE INDEX idx ON tst USING hnsw (v $opclass); + )); + + # Test approximate results + my $min = $operator eq "<\%>" ? 0.96 : 0.99; + test_recall($min, $operator); + + $node->safe_psql("postgres", "DROP INDEX idx;"); + + # Build index in parallel in memory + my ($ret, $stdout, $stderr) = $node->psql("postgres", qq( + SET client_min_messages = DEBUG; + SET min_parallel_table_scan_size = 1; + CREATE INDEX idx ON tst USING hnsw (v $opclass); + )); + is($ret, 0, $stderr); + like($stderr, qr/using \d+ parallel workers/); + + # Test approximate results + test_recall($min, $operator); + + $node->safe_psql("postgres", "DROP INDEX idx;"); + + # Build index in parallel on disk + # Set parallel_workers on table to use workers with low maintenance_work_mem + ($ret, $stdout, $stderr) = $node->psql("postgres", qq( + ALTER TABLE tst SET (parallel_workers = 2); + SET client_min_messages = DEBUG; + SET maintenance_work_mem = '4MB'; + CREATE INDEX idx ON tst USING hnsw (v $opclass); + ALTER TABLE tst RESET (parallel_workers); + )); + is($ret, 0, $stderr); + like($stderr, qr/using \d+ parallel workers/); + like($stderr, qr/hnsw graph no longer fits into maintenance_work_mem/); + + $node->safe_psql("postgres", "DROP INDEX idx;"); +} + +done_testing();