Added support for bit vectors to HNSW

This commit is contained in:
Andrew Kane
2024-04-01 20:30:55 -07:00
parent 7ee9074a9c
commit 94a444f029
21 changed files with 541 additions and 5 deletions

View File

@@ -1,3 +1,10 @@
## 0.7.0 (unreleased)
- Added support for bit vectors to HNSW
- Added `hamming_distance` function
- Added `jaccard_distance` function
- Added `quantize_binary` function
## 0.6.2 (2024-03-18)
- Reduced lock contention with parallel HNSW index builds

View File

@@ -3,7 +3,7 @@ EXTVERSION = 0.6.2
MODULE_big = vector
DATA = $(wildcard sql/*--*.sql)
OBJS = src/hnsw.o src/hnswbuild.o src/hnswinsert.o src/hnswscan.o src/hnswutils.o src/hnswvacuum.o src/ivfbuild.o src/ivfflat.o src/ivfinsert.o src/ivfkmeans.o src/ivfscan.o src/ivfutils.o src/ivfvacuum.o src/vector.o
OBJS = src/bitvector.o src/hnsw.o src/hnswbuild.o src/hnswinsert.o src/hnswscan.o src/hnswutils.o src/hnswvacuum.o src/ivfbuild.o src/ivfflat.o src/ivfinsert.o src/ivfkmeans.o src/ivfscan.o src/ivfutils.o src/ivfvacuum.o src/vector.o
HEADERS = src/vector.h
TESTS = $(wildcard test/sql/*.sql)

View File

@@ -1,7 +1,7 @@
EXTENSION = vector
EXTVERSION = 0.6.2
OBJS = src\hnsw.obj src\hnswbuild.obj src\hnswinsert.obj src\hnswscan.obj src\hnswutils.obj src\hnswvacuum.obj src\ivfbuild.obj src\ivfflat.obj src\ivfinsert.obj src\ivfkmeans.obj src\ivfscan.obj src\ivfutils.obj src\ivfvacuum.obj src\vector.obj
OBJS = src\bitvector.obj src\hnsw.obj src\hnswbuild.obj src\hnswinsert.obj src\hnswscan.obj src\hnswutils.obj src\hnswvacuum.obj src\ivfbuild.obj src\ivfflat.obj src\ivfinsert.obj src\ivfkmeans.obj src\ivfscan.obj src\ivfutils.obj src\ivfvacuum.obj src\vector.obj
HEADERS = src\vector.h
REGRESS = btree cast copy functions input ivfflat_cosine ivfflat_ip ivfflat_l2 ivfflat_options ivfflat_unlogged

View File

@@ -221,7 +221,19 @@ Cosine distance
CREATE INDEX ON items USING hnsw (embedding vector_cosine_ops);
```
Vectors with up to 2,000 dimensions can be indexed.
Hamming distance - unreleased
```sql
CREATE INDEX ON items USING hnsw (embedding bit_hamming_ops);
```
Jaccard distance - unreleased
```sql
CREATE INDEX ON items USING hnsw (embedding bit_jaccard_ops);
```
Vectors with up to 2,000 dimensions can be indexed, or bit vectors with up to 64,000 dimensions.
### Index Options
@@ -699,6 +711,9 @@ Also, note that `NULL` vectors are not indexed (as well as zero vectors for cosi
## Reference
- [Vector](#vector-type)
- [Bit](#bit-type)
### Vector Type
Each vector takes `4 * dimensions + 8` bytes of storage. Each element is a single-precision floating-point number (like the `real` type in Postgres), and all elements must be finite (no `NaN`, `Infinity` or `-Infinity`). Vectors can have up to 16,000 dimensions.
@@ -722,6 +737,7 @@ cosine_distance(vector, vector) → double precision | cosine distance |
inner_product(vector, vector) → double precision | inner product |
l2_distance(vector, vector) → double precision | Euclidean distance |
l1_distance(vector, vector) → double precision | taxicab distance | 0.5.0
quantize_binary(vector) → bit | quantize | unreleased
vector_dims(vector) → integer | number of dimensions |
vector_norm(vector) → double precision | Euclidean norm |
@@ -732,6 +748,24 @@ Function | Description | Added
avg(vector) → vector | average |
sum(vector) → vector | sum | 0.5.0
### Bit Type
Each bit vector takes `dimensions / 8 + (5 or 8)` bytes of storage. See the [Postgres docs](https://www.postgresql.org/docs/current/datatype-bit.html) for more info.
### Bit Operators
Operator | Description | Added
--- | --- | ---
<~> | Hamming distance | unreleased
<%> | Jaccard distance | unreleased
### Bit Functions
Function | Description | Added
--- | --- | ---
hamming_distance(bit, bit) → double precision | Hamming distance | unreleased
jaccard_distance(bit, bit) → double precision | Jaccard distance | unreleased
## Installation Notes - Linux and Mac
### Postgres Location

View File

@@ -0,0 +1,31 @@
-- complain if script is sourced in psql, rather than via CREATE EXTENSION
\echo Use "ALTER EXTENSION vector UPDATE TO '0.7.0'" to load this file. \quit
CREATE FUNCTION quantize_binary(vector) RETURNS bit
AS 'MODULE_PATHNAME' LANGUAGE C IMMUTABLE STRICT PARALLEL SAFE;
CREATE FUNCTION hamming_distance(bit, bit) RETURNS float8
AS 'MODULE_PATHNAME' LANGUAGE C IMMUTABLE STRICT PARALLEL SAFE;
CREATE FUNCTION jaccard_distance(bit, bit) RETURNS float8
AS 'MODULE_PATHNAME' LANGUAGE C IMMUTABLE STRICT PARALLEL SAFE;
CREATE OPERATOR <~> (
LEFTARG = bit, RIGHTARG = bit, PROCEDURE = hamming_distance,
COMMUTATOR = '<~>'
);
CREATE OPERATOR <%> (
LEFTARG = bit, RIGHTARG = bit, PROCEDURE = jaccard_distance,
COMMUTATOR = '<%>'
);
CREATE OPERATOR CLASS bit_hamming_ops
FOR TYPE bit USING hnsw AS
OPERATOR 1 <~> (bit, bit) FOR ORDER BY float_ops,
FUNCTION 1 hamming_distance(bit, bit);
CREATE OPERATOR CLASS bit_jaccard_ops
FOR TYPE bit USING hnsw AS
OPERATOR 1 <%> (bit, bit) FOR ORDER BY float_ops,
FUNCTION 1 jaccard_distance(bit, bit);

View File

@@ -58,6 +58,9 @@ CREATE FUNCTION vector_sub(vector, vector) RETURNS vector
CREATE FUNCTION vector_mul(vector, vector) RETURNS vector
AS 'MODULE_PATHNAME' LANGUAGE C IMMUTABLE STRICT PARALLEL SAFE;
CREATE FUNCTION quantize_binary(vector) RETURNS bit
AS 'MODULE_PATHNAME' LANGUAGE C IMMUTABLE STRICT PARALLEL SAFE;
-- vector private functions
CREATE FUNCTION vector_lt(vector, vector) RETURNS bool
@@ -287,3 +290,31 @@ CREATE OPERATOR CLASS vector_cosine_ops
OPERATOR 1 <=> (vector, vector) FOR ORDER BY float_ops,
FUNCTION 1 vector_negative_inner_product(vector, vector),
FUNCTION 2 vector_norm(vector);
-- bit functions
CREATE FUNCTION hamming_distance(bit, bit) RETURNS float8
AS 'MODULE_PATHNAME' LANGUAGE C IMMUTABLE STRICT PARALLEL SAFE;
CREATE FUNCTION jaccard_distance(bit, bit) RETURNS float8
AS 'MODULE_PATHNAME' LANGUAGE C IMMUTABLE STRICT PARALLEL SAFE;
CREATE OPERATOR <~> (
LEFTARG = bit, RIGHTARG = bit, PROCEDURE = hamming_distance,
COMMUTATOR = '<~>'
);
CREATE OPERATOR <%> (
LEFTARG = bit, RIGHTARG = bit, PROCEDURE = jaccard_distance,
COMMUTATOR = '<%>'
);
CREATE OPERATOR CLASS bit_hamming_ops
FOR TYPE bit USING hnsw AS
OPERATOR 1 <~> (bit, bit) FOR ORDER BY float_ops,
FUNCTION 1 hamming_distance(bit, bit);
CREATE OPERATOR CLASS bit_jaccard_ops
FOR TYPE bit USING hnsw AS
OPERATOR 1 <%> (bit, bit) FOR ORDER BY float_ops,
FUNCTION 1 jaccard_distance(bit, bit);

90
src/bitvector.c Normal file
View File

@@ -0,0 +1,90 @@
#include "postgres.h"
#include "bitvector.h"
#include "port/pg_bitutils.h"
#include "utils/varbit.h"
#if PG_VERSION_NUM >= 160000
#include "varatt.h"
#endif
/*
* Allocate and initialize a new bit vector
*/
VarBit *
InitBitVector(int dim)
{
VarBit *result;
int size;
size = VARBITTOTALLEN(dim);
result = (VarBit *) palloc0(size);
SET_VARSIZE(result, size);
VARBITLEN(result) = dim;
return result;
}
/*
* Ensure same number of bits
*/
static inline void
CheckDims(VarBit *a, VarBit *b)
{
if (VARBITLEN(a) != VARBITLEN(b))
ereport(ERROR,
(errcode(ERRCODE_DATA_EXCEPTION),
errmsg("different bit lengths %u and %u", VARBITLEN(a), VARBITLEN(b))));
}
/*
* Get the Hamming distance between two bit strings
*/
PGDLLEXPORT PG_FUNCTION_INFO_V1(hamming_distance);
Datum
hamming_distance(PG_FUNCTION_ARGS)
{
VarBit *a = PG_GETARG_VARBIT_P(0);
VarBit *b = PG_GETARG_VARBIT_P(1);
unsigned char *ax = VARBITS(a);
unsigned char *bx = VARBITS(b);
uint64 distance = 0;
CheckDims(a, b);
/* TODO Improve performance */
for (uint32 i = 0; i < VARBITBYTES(a); i++)
distance += pg_number_of_ones[ax[i] ^ bx[i]];
PG_RETURN_FLOAT8((double) distance);
}
/*
* Get the Jaccard distance between two bit strings
*/
PGDLLEXPORT PG_FUNCTION_INFO_V1(jaccard_distance);
Datum
jaccard_distance(PG_FUNCTION_ARGS)
{
VarBit *a = PG_GETARG_VARBIT_P(0);
VarBit *b = PG_GETARG_VARBIT_P(1);
unsigned char *ax = VARBITS(a);
unsigned char *bx = VARBITS(b);
uint64 ab = 0;
uint64 aa;
uint64 bb;
CheckDims(a, b);
/* TODO Improve performance */
for (uint32 i = 0; i < VARBITBYTES(a); i++)
ab += pg_number_of_ones[ax[i] & bx[i]];
if (ab == 0)
PG_RETURN_FLOAT8(1);
aa = pg_popcount((char *) ax, VARBITBYTES(a));
bb = pg_popcount((char *) bx, VARBITBYTES(b));
PG_RETURN_FLOAT8(1 - (ab / ((double) (aa + bb - ab))));
}

8
src/bitvector.h Normal file
View File

@@ -0,0 +1,8 @@
#ifndef BITVECTOR_H
#define BITVECTOR_H
#include "utils/varbit.h"
VarBit *InitBitVector(int dim);
#endif

View File

@@ -57,7 +57,8 @@
typedef enum HnswType
{
HNSW_TYPE_VECTOR
HNSW_TYPE_VECTOR,
HNSW_TYPE_BIT
} HnswType;
/* Build phases */

View File

@@ -44,6 +44,7 @@
#include "access/xact.h"
#include "access/xloginsert.h"
#include "catalog/index.h"
#include "catalog/pg_type_d.h"
#include "commands/progress.h"
#include "hnsw.h"
#include "miscadmin.h"
@@ -665,13 +666,27 @@ HnswSharedMemoryAlloc(Size size, void *state)
return chunk;
}
/*
* Get max dimensions
*/
static int
GetMaxDimensions(HnswType type)
{
int maxDimensions = HNSW_MAX_DIM;
if (type == HNSW_TYPE_BIT)
maxDimensions *= 32;
return maxDimensions;
}
/*
* Initialize the build state
*/
static void
InitBuildState(HnswBuildState * buildstate, Relation heap, Relation index, IndexInfo *indexInfo, ForkNumber forkNum)
{
int maxDimensions = HNSW_MAX_DIM;
int maxDimensions;
buildstate->heap = heap;
buildstate->index = index;
@@ -683,6 +698,8 @@ InitBuildState(HnswBuildState * buildstate, Relation heap, Relation index, Index
buildstate->efConstruction = HnswGetEfConstruction(index);
buildstate->dimensions = TupleDescAttr(index->rd_att, 0)->atttypmod;
maxDimensions = GetMaxDimensions(buildstate->type);
/* Require column to have dimensions to be indexed */
if (buildstate->dimensions < 0)
elog(ERROR, "column does not have dimensions");

View File

@@ -1,6 +1,8 @@
#include "postgres.h"
#include "access/relscan.h"
#include "bitvector.h"
#include "catalog/pg_type_d.h"
#include "hnsw.h"
#include "pgstat.h"
#include "storage/bufmgr.h"

View File

@@ -2,6 +2,7 @@
#include <math.h>
#include "bitvector.h"
#include "catalog/pg_type.h"
#include "common/shortest_dec.h"
#include "fmgr.h"
@@ -858,6 +859,25 @@ vector_mul(PG_FUNCTION_ARGS)
PG_RETURN_POINTER(result);
}
/*
* Quantize a vector
*/
PGDLLEXPORT PG_FUNCTION_INFO_V1(quantize_binary);
Datum
quantize_binary(PG_FUNCTION_ARGS)
{
Vector *a = PG_GETARG_VECTOR_P(0);
float *ax = a->x;
VarBit *result = InitBitVector(a->dim);
unsigned char *rx = VARBITS(result);
for (int i = 0; i < a->dim; i++)
rx[i / 8] |= (ax[i] > 0) << (7 - (i % 8));
PG_RETURN_VARBIT_P(result);
}
/*
* Internal helper to compare vectors
*/

View File

@@ -0,0 +1,64 @@
SELECT hamming_distance(B'111', B'111');
hamming_distance
------------------
0
(1 row)
SELECT hamming_distance(B'111', B'110');
hamming_distance
------------------
1
(1 row)
SELECT hamming_distance(B'111', B'100');
hamming_distance
------------------
2
(1 row)
SELECT hamming_distance(B'111', B'000');
hamming_distance
------------------
3
(1 row)
SELECT hamming_distance(B'111', B'00');
ERROR: different bit lengths 3 and 2
SELECT jaccard_distance(B'1111', B'1111');
jaccard_distance
------------------
0
(1 row)
SELECT jaccard_distance(B'1111', B'1110');
jaccard_distance
------------------
0.25
(1 row)
SELECT jaccard_distance(B'1111', B'1100');
jaccard_distance
------------------
0.5
(1 row)
SELECT jaccard_distance(B'1111', B'1000');
jaccard_distance
------------------
0.75
(1 row)
SELECT jaccard_distance(B'1111', B'0000');
jaccard_distance
------------------
1
(1 row)
SELECT jaccard_distance(B'1100', B'1000');
jaccard_distance
------------------
0.5
(1 row)
SELECT jaccard_distance(B'1111', B'000');
ERROR: different bit lengths 4 and 3

View File

@@ -208,6 +208,18 @@ SELECT l1_distance('[3e38]'::vector, '[-3e38]');
Infinity
(1 row)
SELECT quantize_binary('[1,0,-1]');
quantize_binary
-----------------
100
(1 row)
SELECT quantize_binary('[0,0.1,-0.2,-0.3,0.4,0.5,0.6,-0.7,0.8,-0.9,1]');
quantize_binary
-----------------
01001110101
(1 row)
SELECT avg(v) FROM unnest(ARRAY['[1,2,3]'::vector, '[3,5,7]']) v;
avg
-----------

View File

@@ -0,0 +1,21 @@
SET enable_seqscan = off;
CREATE TABLE t (val bit(3));
INSERT INTO t (val) VALUES (B'000'), (B'100'), (B'111'), (NULL);
CREATE INDEX ON t USING hnsw (val bit_hamming_ops);
INSERT INTO t (val) VALUES (B'110');
SELECT * FROM t ORDER BY val <~> B'111';
val
-----
111
110
100
000
(4 rows)
SELECT COUNT(*) FROM (SELECT * FROM t ORDER BY val <~> (SELECT NULL::bit)) t2;
count
-------
4
(1 row)
DROP TABLE t;

View File

@@ -0,0 +1,21 @@
SET enable_seqscan = off;
CREATE TABLE t (val bit(4));
INSERT INTO t (val) VALUES (B'0000'), (B'1100'), (B'1111'), (NULL);
CREATE INDEX ON t USING hnsw (val bit_jaccard_ops);
INSERT INTO t (val) VALUES (B'1110');
SELECT * FROM t ORDER BY val <%> B'1111';
val
------
1111
1110
1100
0000
(4 rows)
SELECT COUNT(*) FROM (SELECT * FROM t ORDER BY val <%> (SELECT NULL::bit)) t2;
count
-------
4
(1 row)
DROP TABLE t;

View File

@@ -0,0 +1,13 @@
SELECT hamming_distance(B'111', B'111');
SELECT hamming_distance(B'111', B'110');
SELECT hamming_distance(B'111', B'100');
SELECT hamming_distance(B'111', B'000');
SELECT hamming_distance(B'111', B'00');
SELECT jaccard_distance(B'1111', B'1111');
SELECT jaccard_distance(B'1111', B'1110');
SELECT jaccard_distance(B'1111', B'1100');
SELECT jaccard_distance(B'1111', B'1000');
SELECT jaccard_distance(B'1111', B'0000');
SELECT jaccard_distance(B'1100', B'1000');
SELECT jaccard_distance(B'1111', B'000');

View File

@@ -48,6 +48,9 @@ SELECT l1_distance('[0,0]'::vector, '[0,1]');
SELECT l1_distance('[1,2]'::vector, '[3]');
SELECT l1_distance('[3e38]'::vector, '[-3e38]');
SELECT quantize_binary('[1,0,-1]');
SELECT quantize_binary('[0,0.1,-0.2,-0.3,0.4,0.5,0.6,-0.7,0.8,-0.9,1]');
SELECT avg(v) FROM unnest(ARRAY['[1,2,3]'::vector, '[3,5,7]']) v;
SELECT avg(v) FROM unnest(ARRAY['[1,2,3]'::vector, '[3,5,7]', NULL]) v;
SELECT avg(v) FROM unnest(ARRAY[]::vector[]) v;

12
test/sql/hnsw_hamming.sql Normal file
View File

@@ -0,0 +1,12 @@
SET enable_seqscan = off;
CREATE TABLE t (val bit(3));
INSERT INTO t (val) VALUES (B'000'), (B'100'), (B'111'), (NULL);
CREATE INDEX ON t USING hnsw (val bit_hamming_ops);
INSERT INTO t (val) VALUES (B'110');
SELECT * FROM t ORDER BY val <~> B'111';
SELECT COUNT(*) FROM (SELECT * FROM t ORDER BY val <~> (SELECT NULL::bit)) t2;
DROP TABLE t;

12
test/sql/hnsw_jaccard.sql Normal file
View File

@@ -0,0 +1,12 @@
SET enable_seqscan = off;
CREATE TABLE t (val bit(4));
INSERT INTO t (val) VALUES (B'0000'), (B'1100'), (B'1111'), (NULL);
CREATE INDEX ON t USING hnsw (val bit_jaccard_ops);
INSERT INTO t (val) VALUES (B'1110');
SELECT * FROM t ORDER BY val <%> B'1111';
SELECT COUNT(*) FROM (SELECT * FROM t ORDER BY val <%> (SELECT NULL::bit)) t2;
DROP TABLE t;

View File

@@ -0,0 +1,137 @@
use strict;
use warnings;
use PostgresNode;
use TestLib;
use Test::More;
my $node;
my @queries = ();
my @expected;
my $limit = 20;
my $dim = 52;
my $max = 2**$dim;
sub test_recall
{
my ($min, $operator) = @_;
my $correct = 0;
my $total = 0;
my $explain = $node->safe_psql("postgres", qq(
SET enable_seqscan = off;
SET hnsw.ef_search = 100;
EXPLAIN ANALYZE SELECT i FROM tst ORDER BY v $operator $queries[0] LIMIT $limit;
));
like($explain, qr/Index Scan/);
for my $i (0 .. $#queries)
{
my $actual = $node->safe_psql("postgres", qq(
SET enable_seqscan = off;
SET hnsw.ef_search = 100;
SELECT i FROM tst ORDER BY v $operator $queries[$i] LIMIT $limit;
));
my @actual_ids = split("\n", $actual);
my @expected_ids = split("\n", $expected[$i]);
my %expected_set = map { $_ => 1 } @expected_ids;
foreach (@actual_ids)
{
if (exists($expected_set{$_}))
{
$correct++;
}
}
$total += $limit;
}
cmp_ok($correct / $total, ">=", $min, $operator);
}
# Initialize node
$node = get_new_node('node');
$node->init;
$node->start;
# Create table
$node->safe_psql("postgres", "CREATE EXTENSION vector;");
$node->safe_psql("postgres", "CREATE TABLE tst (i int4, v bit($dim));");
$node->safe_psql("postgres",
"INSERT INTO tst SELECT i, (random() * $max)::bigint::bit($dim) FROM generate_series(1, 10000) i;"
);
# Generate queries
for (1 .. 20)
{
my $r = int(rand() * $max);
push(@queries, "${r}::bigint::bit($dim)");
}
# Check each index type
my @operators = ("<~>", "<\%>");
my @opclasses = ("bit_hamming_ops", "bit_jaccard_ops");
for my $i (0 .. $#operators)
{
my $operator = $operators[$i];
my $opclass = $opclasses[$i];
# Get exact results
@expected = ();
foreach (@queries)
{
# Handle ties
my $res = $node->safe_psql("postgres", qq(
WITH top AS (
SELECT v $operator $_ AS distance FROM tst ORDER BY v $operator $_ LIMIT $limit
)
SELECT i FROM tst WHERE (v $operator $_) <= (SELECT MAX(distance) FROM top)
));
push(@expected, $res);
}
# Build index serially
$node->safe_psql("postgres", qq(
SET max_parallel_maintenance_workers = 0;
CREATE INDEX idx ON tst USING hnsw (v $opclass);
));
# Test approximate results
my $min = $operator eq "<\%>" ? 0.96 : 0.99;
test_recall($min, $operator);
$node->safe_psql("postgres", "DROP INDEX idx;");
# Build index in parallel in memory
my ($ret, $stdout, $stderr) = $node->psql("postgres", qq(
SET client_min_messages = DEBUG;
SET min_parallel_table_scan_size = 1;
CREATE INDEX idx ON tst USING hnsw (v $opclass);
));
is($ret, 0, $stderr);
like($stderr, qr/using \d+ parallel workers/);
# Test approximate results
test_recall($min, $operator);
$node->safe_psql("postgres", "DROP INDEX idx;");
# Build index in parallel on disk
# Set parallel_workers on table to use workers with low maintenance_work_mem
($ret, $stdout, $stderr) = $node->psql("postgres", qq(
ALTER TABLE tst SET (parallel_workers = 2);
SET client_min_messages = DEBUG;
SET maintenance_work_mem = '4MB';
CREATE INDEX idx ON tst USING hnsw (v $opclass);
ALTER TABLE tst RESET (parallel_workers);
));
is($ret, 0, $stderr);
like($stderr, qr/using \d+ parallel workers/);
like($stderr, qr/hnsw graph no longer fits into maintenance_work_mem/);
$node->safe_psql("postgres", "DROP INDEX idx;");
}
done_testing();