mirror of
https://github.com/pgvector/pgvector.git
synced 2026-06-06 05:51:21 +08:00
Added support for bit vectors to HNSW
This commit is contained in:
@@ -1,3 +1,10 @@
|
||||
## 0.7.0 (unreleased)
|
||||
|
||||
- Added support for bit vectors to HNSW
|
||||
- Added `hamming_distance` function
|
||||
- Added `jaccard_distance` function
|
||||
- Added `quantize_binary` function
|
||||
|
||||
## 0.6.2 (2024-03-18)
|
||||
|
||||
- Reduced lock contention with parallel HNSW index builds
|
||||
|
||||
2
Makefile
2
Makefile
@@ -3,7 +3,7 @@ EXTVERSION = 0.6.2
|
||||
|
||||
MODULE_big = vector
|
||||
DATA = $(wildcard sql/*--*.sql)
|
||||
OBJS = src/hnsw.o src/hnswbuild.o src/hnswinsert.o src/hnswscan.o src/hnswutils.o src/hnswvacuum.o src/ivfbuild.o src/ivfflat.o src/ivfinsert.o src/ivfkmeans.o src/ivfscan.o src/ivfutils.o src/ivfvacuum.o src/vector.o
|
||||
OBJS = src/bitvector.o src/hnsw.o src/hnswbuild.o src/hnswinsert.o src/hnswscan.o src/hnswutils.o src/hnswvacuum.o src/ivfbuild.o src/ivfflat.o src/ivfinsert.o src/ivfkmeans.o src/ivfscan.o src/ivfutils.o src/ivfvacuum.o src/vector.o
|
||||
HEADERS = src/vector.h
|
||||
|
||||
TESTS = $(wildcard test/sql/*.sql)
|
||||
|
||||
@@ -1,7 +1,7 @@
|
||||
EXTENSION = vector
|
||||
EXTVERSION = 0.6.2
|
||||
|
||||
OBJS = src\hnsw.obj src\hnswbuild.obj src\hnswinsert.obj src\hnswscan.obj src\hnswutils.obj src\hnswvacuum.obj src\ivfbuild.obj src\ivfflat.obj src\ivfinsert.obj src\ivfkmeans.obj src\ivfscan.obj src\ivfutils.obj src\ivfvacuum.obj src\vector.obj
|
||||
OBJS = src\bitvector.obj src\hnsw.obj src\hnswbuild.obj src\hnswinsert.obj src\hnswscan.obj src\hnswutils.obj src\hnswvacuum.obj src\ivfbuild.obj src\ivfflat.obj src\ivfinsert.obj src\ivfkmeans.obj src\ivfscan.obj src\ivfutils.obj src\ivfvacuum.obj src\vector.obj
|
||||
HEADERS = src\vector.h
|
||||
|
||||
REGRESS = btree cast copy functions input ivfflat_cosine ivfflat_ip ivfflat_l2 ivfflat_options ivfflat_unlogged
|
||||
|
||||
36
README.md
36
README.md
@@ -221,7 +221,19 @@ Cosine distance
|
||||
CREATE INDEX ON items USING hnsw (embedding vector_cosine_ops);
|
||||
```
|
||||
|
||||
Vectors with up to 2,000 dimensions can be indexed.
|
||||
Hamming distance - unreleased
|
||||
|
||||
```sql
|
||||
CREATE INDEX ON items USING hnsw (embedding bit_hamming_ops);
|
||||
```
|
||||
|
||||
Jaccard distance - unreleased
|
||||
|
||||
```sql
|
||||
CREATE INDEX ON items USING hnsw (embedding bit_jaccard_ops);
|
||||
```
|
||||
|
||||
Vectors with up to 2,000 dimensions can be indexed, or bit vectors with up to 64,000 dimensions.
|
||||
|
||||
### Index Options
|
||||
|
||||
@@ -699,6 +711,9 @@ Also, note that `NULL` vectors are not indexed (as well as zero vectors for cosi
|
||||
|
||||
## Reference
|
||||
|
||||
- [Vector](#vector-type)
|
||||
- [Bit](#bit-type)
|
||||
|
||||
### Vector Type
|
||||
|
||||
Each vector takes `4 * dimensions + 8` bytes of storage. Each element is a single-precision floating-point number (like the `real` type in Postgres), and all elements must be finite (no `NaN`, `Infinity` or `-Infinity`). Vectors can have up to 16,000 dimensions.
|
||||
@@ -722,6 +737,7 @@ cosine_distance(vector, vector) → double precision | cosine distance |
|
||||
inner_product(vector, vector) → double precision | inner product |
|
||||
l2_distance(vector, vector) → double precision | Euclidean distance |
|
||||
l1_distance(vector, vector) → double precision | taxicab distance | 0.5.0
|
||||
quantize_binary(vector) → bit | quantize | unreleased
|
||||
vector_dims(vector) → integer | number of dimensions |
|
||||
vector_norm(vector) → double precision | Euclidean norm |
|
||||
|
||||
@@ -732,6 +748,24 @@ Function | Description | Added
|
||||
avg(vector) → vector | average |
|
||||
sum(vector) → vector | sum | 0.5.0
|
||||
|
||||
### Bit Type
|
||||
|
||||
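Each bit vector takes `ceil(dimensions / 8) + (5 or 8)` bytes of storage: one byte per group of 8 bits (a partial group still uses a full byte), plus 5 or 8 bytes of overhead depending on the length of the bit string. See the [Postgres docs](https://www.postgresql.org/docs/current/datatype-bit.html) for more info.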
|
||||
|
||||
### Bit Operators
|
||||
|
||||
Operator | Description | Added
|
||||
--- | --- | ---
|
||||
<~> | Hamming distance | unreleased
|
||||
<%> | Jaccard distance | unreleased
|
||||
|
||||
### Bit Functions
|
||||
|
||||
Function | Description | Added
|
||||
--- | --- | ---
|
||||
hamming_distance(bit, bit) → double precision | Hamming distance | unreleased
|
||||
jaccard_distance(bit, bit) → double precision | Jaccard distance | unreleased
|
||||
|
||||
## Installation Notes - Linux and Mac
|
||||
|
||||
### Postgres Location
|
||||
|
||||
31
sql/vector--0.6.2--0.7.0.sql
Normal file
31
sql/vector--0.6.2--0.7.0.sql
Normal file
@@ -0,0 +1,31 @@
|
||||
-- complain if script is sourced in psql, rather than via CREATE EXTENSION
|
||||
\echo Use "ALTER EXTENSION vector UPDATE TO '0.7.0'" to load this file. \quit
|
||||
|
||||
CREATE FUNCTION quantize_binary(vector) RETURNS bit
|
||||
AS 'MODULE_PATHNAME' LANGUAGE C IMMUTABLE STRICT PARALLEL SAFE;
|
||||
|
||||
CREATE FUNCTION hamming_distance(bit, bit) RETURNS float8
|
||||
AS 'MODULE_PATHNAME' LANGUAGE C IMMUTABLE STRICT PARALLEL SAFE;
|
||||
|
||||
CREATE FUNCTION jaccard_distance(bit, bit) RETURNS float8
|
||||
AS 'MODULE_PATHNAME' LANGUAGE C IMMUTABLE STRICT PARALLEL SAFE;
|
||||
|
||||
CREATE OPERATOR <~> (
|
||||
LEFTARG = bit, RIGHTARG = bit, PROCEDURE = hamming_distance,
|
||||
COMMUTATOR = '<~>'
|
||||
);
|
||||
|
||||
CREATE OPERATOR <%> (
|
||||
LEFTARG = bit, RIGHTARG = bit, PROCEDURE = jaccard_distance,
|
||||
COMMUTATOR = '<%>'
|
||||
);
|
||||
|
||||
CREATE OPERATOR CLASS bit_hamming_ops
|
||||
FOR TYPE bit USING hnsw AS
|
||||
OPERATOR 1 <~> (bit, bit) FOR ORDER BY float_ops,
|
||||
FUNCTION 1 hamming_distance(bit, bit);
|
||||
|
||||
CREATE OPERATOR CLASS bit_jaccard_ops
|
||||
FOR TYPE bit USING hnsw AS
|
||||
OPERATOR 1 <%> (bit, bit) FOR ORDER BY float_ops,
|
||||
FUNCTION 1 jaccard_distance(bit, bit);
|
||||
@@ -58,6 +58,9 @@ CREATE FUNCTION vector_sub(vector, vector) RETURNS vector
|
||||
CREATE FUNCTION vector_mul(vector, vector) RETURNS vector
|
||||
AS 'MODULE_PATHNAME' LANGUAGE C IMMUTABLE STRICT PARALLEL SAFE;
|
||||
|
||||
CREATE FUNCTION quantize_binary(vector) RETURNS bit
|
||||
AS 'MODULE_PATHNAME' LANGUAGE C IMMUTABLE STRICT PARALLEL SAFE;
|
||||
|
||||
-- vector private functions
|
||||
|
||||
CREATE FUNCTION vector_lt(vector, vector) RETURNS bool
|
||||
@@ -287,3 +290,31 @@ CREATE OPERATOR CLASS vector_cosine_ops
|
||||
OPERATOR 1 <=> (vector, vector) FOR ORDER BY float_ops,
|
||||
FUNCTION 1 vector_negative_inner_product(vector, vector),
|
||||
FUNCTION 2 vector_norm(vector);
|
||||
|
||||
-- bit functions
|
||||
|
||||
CREATE FUNCTION hamming_distance(bit, bit) RETURNS float8
|
||||
AS 'MODULE_PATHNAME' LANGUAGE C IMMUTABLE STRICT PARALLEL SAFE;
|
||||
|
||||
CREATE FUNCTION jaccard_distance(bit, bit) RETURNS float8
|
||||
AS 'MODULE_PATHNAME' LANGUAGE C IMMUTABLE STRICT PARALLEL SAFE;
|
||||
|
||||
CREATE OPERATOR <~> (
|
||||
LEFTARG = bit, RIGHTARG = bit, PROCEDURE = hamming_distance,
|
||||
COMMUTATOR = '<~>'
|
||||
);
|
||||
|
||||
CREATE OPERATOR <%> (
|
||||
LEFTARG = bit, RIGHTARG = bit, PROCEDURE = jaccard_distance,
|
||||
COMMUTATOR = '<%>'
|
||||
);
|
||||
|
||||
CREATE OPERATOR CLASS bit_hamming_ops
|
||||
FOR TYPE bit USING hnsw AS
|
||||
OPERATOR 1 <~> (bit, bit) FOR ORDER BY float_ops,
|
||||
FUNCTION 1 hamming_distance(bit, bit);
|
||||
|
||||
CREATE OPERATOR CLASS bit_jaccard_ops
|
||||
FOR TYPE bit USING hnsw AS
|
||||
OPERATOR 1 <%> (bit, bit) FOR ORDER BY float_ops,
|
||||
FUNCTION 1 jaccard_distance(bit, bit);
|
||||
|
||||
90
src/bitvector.c
Normal file
90
src/bitvector.c
Normal file
@@ -0,0 +1,90 @@
|
||||
#include "postgres.h"
|
||||
|
||||
#include "bitvector.h"
|
||||
#include "port/pg_bitutils.h"
|
||||
#include "utils/varbit.h"
|
||||
|
||||
#if PG_VERSION_NUM >= 160000
|
||||
#include "varatt.h"
|
||||
#endif
|
||||
|
||||
/*
|
||||
* Allocate and initialize a new bit vector
|
||||
*/
|
||||
VarBit *
|
||||
InitBitVector(int dim)
|
||||
{
|
||||
VarBit *result;
|
||||
int size;
|
||||
|
||||
size = VARBITTOTALLEN(dim);
|
||||
result = (VarBit *) palloc0(size);
|
||||
SET_VARSIZE(result, size);
|
||||
VARBITLEN(result) = dim;
|
||||
|
||||
return result;
|
||||
}
|
||||
|
||||
/*
|
||||
* Ensure same number of bits
|
||||
*/
|
||||
static inline void
|
||||
CheckDims(VarBit *a, VarBit *b)
|
||||
{
|
||||
if (VARBITLEN(a) != VARBITLEN(b))
|
||||
ereport(ERROR,
|
||||
(errcode(ERRCODE_DATA_EXCEPTION),
|
||||
errmsg("different bit lengths %u and %u", VARBITLEN(a), VARBITLEN(b))));
|
||||
}
|
||||
|
||||
/*
|
||||
* Get the Hamming distance between two bit strings
|
||||
*/
|
||||
PGDLLEXPORT PG_FUNCTION_INFO_V1(hamming_distance);
|
||||
Datum
|
||||
hamming_distance(PG_FUNCTION_ARGS)
|
||||
{
|
||||
VarBit *a = PG_GETARG_VARBIT_P(0);
|
||||
VarBit *b = PG_GETARG_VARBIT_P(1);
|
||||
unsigned char *ax = VARBITS(a);
|
||||
unsigned char *bx = VARBITS(b);
|
||||
uint64 distance = 0;
|
||||
|
||||
CheckDims(a, b);
|
||||
|
||||
/* TODO Improve performance */
|
||||
for (uint32 i = 0; i < VARBITBYTES(a); i++)
|
||||
distance += pg_number_of_ones[ax[i] ^ bx[i]];
|
||||
|
||||
PG_RETURN_FLOAT8((double) distance);
|
||||
}
|
||||
|
||||
/*
|
||||
* Get the Jaccard distance between two bit strings
|
||||
*/
|
||||
PGDLLEXPORT PG_FUNCTION_INFO_V1(jaccard_distance);
|
||||
Datum
|
||||
jaccard_distance(PG_FUNCTION_ARGS)
|
||||
{
|
||||
VarBit *a = PG_GETARG_VARBIT_P(0);
|
||||
VarBit *b = PG_GETARG_VARBIT_P(1);
|
||||
unsigned char *ax = VARBITS(a);
|
||||
unsigned char *bx = VARBITS(b);
|
||||
uint64 ab = 0;
|
||||
uint64 aa;
|
||||
uint64 bb;
|
||||
|
||||
CheckDims(a, b);
|
||||
|
||||
/* TODO Improve performance */
|
||||
for (uint32 i = 0; i < VARBITBYTES(a); i++)
|
||||
ab += pg_number_of_ones[ax[i] & bx[i]];
|
||||
|
||||
if (ab == 0)
|
||||
PG_RETURN_FLOAT8(1);
|
||||
|
||||
aa = pg_popcount((char *) ax, VARBITBYTES(a));
|
||||
bb = pg_popcount((char *) bx, VARBITBYTES(b));
|
||||
|
||||
PG_RETURN_FLOAT8(1 - (ab / ((double) (aa + bb - ab))));
|
||||
}
|
||||
8
src/bitvector.h
Normal file
8
src/bitvector.h
Normal file
@@ -0,0 +1,8 @@
|
||||
#ifndef BITVECTOR_H
|
||||
#define BITVECTOR_H
|
||||
|
||||
#include "utils/varbit.h"
|
||||
|
||||
VarBit *InitBitVector(int dim);
|
||||
|
||||
#endif
|
||||
@@ -57,7 +57,8 @@
|
||||
|
||||
typedef enum HnswType
|
||||
{
|
||||
HNSW_TYPE_VECTOR
|
||||
HNSW_TYPE_VECTOR,
|
||||
HNSW_TYPE_BIT
|
||||
} HnswType;
|
||||
|
||||
/* Build phases */
|
||||
|
||||
@@ -44,6 +44,7 @@
|
||||
#include "access/xact.h"
|
||||
#include "access/xloginsert.h"
|
||||
#include "catalog/index.h"
|
||||
#include "catalog/pg_type_d.h"
|
||||
#include "commands/progress.h"
|
||||
#include "hnsw.h"
|
||||
#include "miscadmin.h"
|
||||
@@ -665,13 +666,27 @@ HnswSharedMemoryAlloc(Size size, void *state)
|
||||
return chunk;
|
||||
}
|
||||
|
||||
/*
|
||||
* Get max dimensions
|
||||
*/
|
||||
static int
|
||||
GetMaxDimensions(HnswType type)
|
||||
{
|
||||
int maxDimensions = HNSW_MAX_DIM;
|
||||
|
||||
if (type == HNSW_TYPE_BIT)
|
||||
maxDimensions *= 32;
|
||||
|
||||
return maxDimensions;
|
||||
}
|
||||
|
||||
/*
|
||||
* Initialize the build state
|
||||
*/
|
||||
static void
|
||||
InitBuildState(HnswBuildState * buildstate, Relation heap, Relation index, IndexInfo *indexInfo, ForkNumber forkNum)
|
||||
{
|
||||
int maxDimensions = HNSW_MAX_DIM;
|
||||
int maxDimensions;
|
||||
|
||||
buildstate->heap = heap;
|
||||
buildstate->index = index;
|
||||
@@ -683,6 +698,8 @@ InitBuildState(HnswBuildState * buildstate, Relation heap, Relation index, Index
|
||||
buildstate->efConstruction = HnswGetEfConstruction(index);
|
||||
buildstate->dimensions = TupleDescAttr(index->rd_att, 0)->atttypmod;
|
||||
|
||||
maxDimensions = GetMaxDimensions(buildstate->type);
|
||||
|
||||
/* Require column to have dimensions to be indexed */
|
||||
if (buildstate->dimensions < 0)
|
||||
elog(ERROR, "column does not have dimensions");
|
||||
|
||||
@@ -1,6 +1,8 @@
|
||||
#include "postgres.h"
|
||||
|
||||
#include "access/relscan.h"
|
||||
#include "bitvector.h"
|
||||
#include "catalog/pg_type_d.h"
|
||||
#include "hnsw.h"
|
||||
#include "pgstat.h"
|
||||
#include "storage/bufmgr.h"
|
||||
|
||||
20
src/vector.c
20
src/vector.c
@@ -2,6 +2,7 @@
|
||||
|
||||
#include <math.h>
|
||||
|
||||
#include "bitvector.h"
|
||||
#include "catalog/pg_type.h"
|
||||
#include "common/shortest_dec.h"
|
||||
#include "fmgr.h"
|
||||
@@ -858,6 +859,25 @@ vector_mul(PG_FUNCTION_ARGS)
|
||||
PG_RETURN_POINTER(result);
|
||||
}
|
||||
|
||||
/*
|
||||
* Quantize a vector
|
||||
*/
|
||||
PGDLLEXPORT PG_FUNCTION_INFO_V1(quantize_binary);
|
||||
Datum
|
||||
quantize_binary(PG_FUNCTION_ARGS)
|
||||
{
|
||||
Vector *a = PG_GETARG_VECTOR_P(0);
|
||||
float *ax = a->x;
|
||||
VarBit *result = InitBitVector(a->dim);
|
||||
unsigned char *rx = VARBITS(result);
|
||||
|
||||
for (int i = 0; i < a->dim; i++)
|
||||
rx[i / 8] |= (ax[i] > 0) << (7 - (i % 8));
|
||||
|
||||
PG_RETURN_VARBIT_P(result);
|
||||
}
|
||||
|
||||
|
||||
/*
|
||||
* Internal helper to compare vectors
|
||||
*/
|
||||
|
||||
64
test/expected/bit_functions.out
Normal file
64
test/expected/bit_functions.out
Normal file
@@ -0,0 +1,64 @@
|
||||
SELECT hamming_distance(B'111', B'111');
|
||||
hamming_distance
|
||||
------------------
|
||||
0
|
||||
(1 row)
|
||||
|
||||
SELECT hamming_distance(B'111', B'110');
|
||||
hamming_distance
|
||||
------------------
|
||||
1
|
||||
(1 row)
|
||||
|
||||
SELECT hamming_distance(B'111', B'100');
|
||||
hamming_distance
|
||||
------------------
|
||||
2
|
||||
(1 row)
|
||||
|
||||
SELECT hamming_distance(B'111', B'000');
|
||||
hamming_distance
|
||||
------------------
|
||||
3
|
||||
(1 row)
|
||||
|
||||
SELECT hamming_distance(B'111', B'00');
|
||||
ERROR: different bit lengths 3 and 2
|
||||
SELECT jaccard_distance(B'1111', B'1111');
|
||||
jaccard_distance
|
||||
------------------
|
||||
0
|
||||
(1 row)
|
||||
|
||||
SELECT jaccard_distance(B'1111', B'1110');
|
||||
jaccard_distance
|
||||
------------------
|
||||
0.25
|
||||
(1 row)
|
||||
|
||||
SELECT jaccard_distance(B'1111', B'1100');
|
||||
jaccard_distance
|
||||
------------------
|
||||
0.5
|
||||
(1 row)
|
||||
|
||||
SELECT jaccard_distance(B'1111', B'1000');
|
||||
jaccard_distance
|
||||
------------------
|
||||
0.75
|
||||
(1 row)
|
||||
|
||||
SELECT jaccard_distance(B'1111', B'0000');
|
||||
jaccard_distance
|
||||
------------------
|
||||
1
|
||||
(1 row)
|
||||
|
||||
SELECT jaccard_distance(B'1100', B'1000');
|
||||
jaccard_distance
|
||||
------------------
|
||||
0.5
|
||||
(1 row)
|
||||
|
||||
SELECT jaccard_distance(B'1111', B'000');
|
||||
ERROR: different bit lengths 4 and 3
|
||||
@@ -208,6 +208,18 @@ SELECT l1_distance('[3e38]'::vector, '[-3e38]');
|
||||
Infinity
|
||||
(1 row)
|
||||
|
||||
SELECT quantize_binary('[1,0,-1]');
|
||||
quantize_binary
|
||||
-----------------
|
||||
100
|
||||
(1 row)
|
||||
|
||||
SELECT quantize_binary('[0,0.1,-0.2,-0.3,0.4,0.5,0.6,-0.7,0.8,-0.9,1]');
|
||||
quantize_binary
|
||||
-----------------
|
||||
01001110101
|
||||
(1 row)
|
||||
|
||||
SELECT avg(v) FROM unnest(ARRAY['[1,2,3]'::vector, '[3,5,7]']) v;
|
||||
avg
|
||||
-----------
|
||||
|
||||
21
test/expected/hnsw_hamming.out
Normal file
21
test/expected/hnsw_hamming.out
Normal file
@@ -0,0 +1,21 @@
|
||||
SET enable_seqscan = off;
|
||||
CREATE TABLE t (val bit(3));
|
||||
INSERT INTO t (val) VALUES (B'000'), (B'100'), (B'111'), (NULL);
|
||||
CREATE INDEX ON t USING hnsw (val bit_hamming_ops);
|
||||
INSERT INTO t (val) VALUES (B'110');
|
||||
SELECT * FROM t ORDER BY val <~> B'111';
|
||||
val
|
||||
-----
|
||||
111
|
||||
110
|
||||
100
|
||||
000
|
||||
(4 rows)
|
||||
|
||||
SELECT COUNT(*) FROM (SELECT * FROM t ORDER BY val <~> (SELECT NULL::bit)) t2;
|
||||
count
|
||||
-------
|
||||
4
|
||||
(1 row)
|
||||
|
||||
DROP TABLE t;
|
||||
21
test/expected/hnsw_jaccard.out
Normal file
21
test/expected/hnsw_jaccard.out
Normal file
@@ -0,0 +1,21 @@
|
||||
SET enable_seqscan = off;
|
||||
CREATE TABLE t (val bit(4));
|
||||
INSERT INTO t (val) VALUES (B'0000'), (B'1100'), (B'1111'), (NULL);
|
||||
CREATE INDEX ON t USING hnsw (val bit_jaccard_ops);
|
||||
INSERT INTO t (val) VALUES (B'1110');
|
||||
SELECT * FROM t ORDER BY val <%> B'1111';
|
||||
val
|
||||
------
|
||||
1111
|
||||
1110
|
||||
1100
|
||||
0000
|
||||
(4 rows)
|
||||
|
||||
SELECT COUNT(*) FROM (SELECT * FROM t ORDER BY val <%> (SELECT NULL::bit)) t2;
|
||||
count
|
||||
-------
|
||||
4
|
||||
(1 row)
|
||||
|
||||
DROP TABLE t;
|
||||
13
test/sql/bit_functions.sql
Normal file
13
test/sql/bit_functions.sql
Normal file
@@ -0,0 +1,13 @@
|
||||
SELECT hamming_distance(B'111', B'111');
|
||||
SELECT hamming_distance(B'111', B'110');
|
||||
SELECT hamming_distance(B'111', B'100');
|
||||
SELECT hamming_distance(B'111', B'000');
|
||||
SELECT hamming_distance(B'111', B'00');
|
||||
|
||||
SELECT jaccard_distance(B'1111', B'1111');
|
||||
SELECT jaccard_distance(B'1111', B'1110');
|
||||
SELECT jaccard_distance(B'1111', B'1100');
|
||||
SELECT jaccard_distance(B'1111', B'1000');
|
||||
SELECT jaccard_distance(B'1111', B'0000');
|
||||
SELECT jaccard_distance(B'1100', B'1000');
|
||||
SELECT jaccard_distance(B'1111', B'000');
|
||||
@@ -48,6 +48,9 @@ SELECT l1_distance('[0,0]'::vector, '[0,1]');
|
||||
SELECT l1_distance('[1,2]'::vector, '[3]');
|
||||
SELECT l1_distance('[3e38]'::vector, '[-3e38]');
|
||||
|
||||
SELECT quantize_binary('[1,0,-1]');
|
||||
SELECT quantize_binary('[0,0.1,-0.2,-0.3,0.4,0.5,0.6,-0.7,0.8,-0.9,1]');
|
||||
|
||||
SELECT avg(v) FROM unnest(ARRAY['[1,2,3]'::vector, '[3,5,7]']) v;
|
||||
SELECT avg(v) FROM unnest(ARRAY['[1,2,3]'::vector, '[3,5,7]', NULL]) v;
|
||||
SELECT avg(v) FROM unnest(ARRAY[]::vector[]) v;
|
||||
|
||||
12
test/sql/hnsw_hamming.sql
Normal file
12
test/sql/hnsw_hamming.sql
Normal file
@@ -0,0 +1,12 @@
|
||||
SET enable_seqscan = off;
|
||||
|
||||
CREATE TABLE t (val bit(3));
|
||||
INSERT INTO t (val) VALUES (B'000'), (B'100'), (B'111'), (NULL);
|
||||
CREATE INDEX ON t USING hnsw (val bit_hamming_ops);
|
||||
|
||||
INSERT INTO t (val) VALUES (B'110');
|
||||
|
||||
SELECT * FROM t ORDER BY val <~> B'111';
|
||||
SELECT COUNT(*) FROM (SELECT * FROM t ORDER BY val <~> (SELECT NULL::bit)) t2;
|
||||
|
||||
DROP TABLE t;
|
||||
12
test/sql/hnsw_jaccard.sql
Normal file
12
test/sql/hnsw_jaccard.sql
Normal file
@@ -0,0 +1,12 @@
|
||||
SET enable_seqscan = off;
|
||||
|
||||
CREATE TABLE t (val bit(4));
|
||||
INSERT INTO t (val) VALUES (B'0000'), (B'1100'), (B'1111'), (NULL);
|
||||
CREATE INDEX ON t USING hnsw (val bit_jaccard_ops);
|
||||
|
||||
INSERT INTO t (val) VALUES (B'1110');
|
||||
|
||||
SELECT * FROM t ORDER BY val <%> B'1111';
|
||||
SELECT COUNT(*) FROM (SELECT * FROM t ORDER BY val <%> (SELECT NULL::bit)) t2;
|
||||
|
||||
DROP TABLE t;
|
||||
137
test/t/020_hnsw_bit_build_recall.pl
Normal file
137
test/t/020_hnsw_bit_build_recall.pl
Normal file
@@ -0,0 +1,137 @@
|
||||
use strict;
|
||||
use warnings;
|
||||
use PostgresNode;
|
||||
use TestLib;
|
||||
use Test::More;
|
||||
|
||||
my $node;
|
||||
my @queries = ();
|
||||
my @expected;
|
||||
my $limit = 20;
|
||||
my $dim = 52;
|
||||
my $max = 2**$dim;
|
||||
|
||||
# Check that index scans using $operator reach at least $min recall.
#
# Arguments:
#   $min      - lowest acceptable ratio of matching ids to ids requested
#   $operator - distance operator placed in the ORDER BY clause
#
# Reads the file-scoped $node, @queries, @expected and $limit. Emits two TAP
# results: one confirming the plan uses an index scan, one for the recall.
sub test_recall
{
	my ($min, $operator) = @_;
	my $hits = 0;
	my $requested = 0;

	# Only the first query is used to confirm that the index is chosen
	my $plan = $node->safe_psql("postgres", qq(
		SET enable_seqscan = off;
		SET hnsw.ef_search = 100;
		EXPLAIN ANALYZE SELECT i FROM tst ORDER BY v $operator $queries[0] LIMIT $limit;
	));
	like($plan, qr/Index Scan/);

	my $position = 0;
	foreach my $query (@queries)
	{
		my $rows = $node->safe_psql("postgres", qq(
			SET enable_seqscan = off;
			SET hnsw.ef_search = 100;
			SELECT i FROM tst ORDER BY v $operator $query LIMIT $limit;
		));

		# Ids accepted as correct for this query
		my %wanted = map { $_ => 1 } split("\n", $expected[$position]);

		# Count the returned ids that appear in the accepted set
		my @returned_ids = split("\n", $rows);
		$hits += scalar(grep { exists($wanted{$_}) } @returned_ids);

		# The denominator is the number of rows asked for, not returned
		$requested += $limit;
		$position++;
	}

	cmp_ok($hits / $requested, ">=", $min, $operator);
}
|
||||
|
||||
# Start a fresh Postgres node for this test
$node = get_new_node('node');
$node->init;
$node->start;

# Install the extension and fill a table with random bit strings
$node->safe_psql("postgres", "CREATE EXTENSION vector;");
$node->safe_psql("postgres", "CREATE TABLE tst (i int4, v bit($dim));");
$node->safe_psql("postgres",
	"INSERT INTO tst SELECT i, (random() * $max)::bigint::bit($dim) FROM generate_series(1, 10000) i;"
);

# Generate 20 random query values, written as SQL cast expressions
foreach my $n (1 .. 20)
{
	my $value = int(rand() * $max);
	push(@queries, sprintf("%s::bigint::bit(%s)", $value, $dim));
}

# Check each index type: every operator is paired with its operator class
my @operators = ("<~>", "<\%>");
my @opclasses = ("bit_hamming_ops", "bit_jaccard_ops");
my %opclass_for;
@opclass_for{@operators} = @opclasses;

foreach my $operator (@operators)
{
	my $opclass = $opclass_for{$operator};

	# Get exact results (no index exists at this point)
	@expected = ();
	foreach my $query (@queries)
	{
		# Handle ties: accept every row at least as close as the last of
		# the top $limit rows
		my $exact = $node->safe_psql("postgres", qq(
			WITH top AS (
				SELECT v $operator $query AS distance FROM tst ORDER BY v $operator $query LIMIT $limit
			)
			SELECT i FROM tst WHERE (v $operator $query) <= (SELECT MAX(distance) FROM top)
		));
		push(@expected, $exact);
	}

	# Jaccard is held to a lower recall threshold than Hamming
	my $min;
	if ($operator eq "<\%>")
	{
		$min = 0.96;
	}
	else
	{
		$min = 0.99;
	}

	# Build index serially
	$node->safe_psql("postgres", qq(
		SET max_parallel_maintenance_workers = 0;
		CREATE INDEX idx ON tst USING hnsw (v $opclass);
	));

	# Test approximate results
	test_recall($min, $operator);

	$node->safe_psql("postgres", "DROP INDEX idx;");

	# Build index in parallel in memory
	my ($ret, $stdout, $stderr) = $node->psql("postgres", qq(
		SET client_min_messages = DEBUG;
		SET min_parallel_table_scan_size = 1;
		CREATE INDEX idx ON tst USING hnsw (v $opclass);
	));
	is($ret, 0, $stderr);
	like($stderr, qr/using \d+ parallel workers/);

	# Test approximate results
	test_recall($min, $operator);

	$node->safe_psql("postgres", "DROP INDEX idx;");

	# Build index in parallel on disk
	# Set parallel_workers on table to use workers with low maintenance_work_mem
	($ret, $stdout, $stderr) = $node->psql("postgres", qq(
		ALTER TABLE tst SET (parallel_workers = 2);
		SET client_min_messages = DEBUG;
		SET maintenance_work_mem = '4MB';
		CREATE INDEX idx ON tst USING hnsw (v $opclass);
		ALTER TABLE tst RESET (parallel_workers);
	));
	is($ret, 0, $stderr);
	like($stderr, qr/using \d+ parallel workers/);
	like($stderr, qr/hnsw graph no longer fits into maintenance_work_mem/);

	$node->safe_psql("postgres", "DROP INDEX idx;");
}

done_testing();
|
||||
Reference in New Issue
Block a user