From 32a502c838b111078ee8d25995446afb00ba38e6 Mon Sep 17 00:00:00 2001 From: Andrew Kane Date: Tue, 2 Apr 2024 13:55:45 -0700 Subject: [PATCH] Added halfvec type --- CHANGELOG.md | 1 + Makefile | 4 +- Makefile.win | 6 +- README.md | 22 + sql/vector--0.6.2--0.7.0.sql | 126 +++ sql/vector.sql | 142 ++++ src/halfvec.c | 969 ++++++++++++++++++++++++ src/halfvec.h | 43 ++ src/hnsw.h | 1 + src/hnswbuild.c | 4 +- src/hnswscan.c | 2 - src/hnswutils.c | 35 +- src/vector.c | 23 + test/expected/cast.out | 24 + test/expected/copy.out | 18 +- test/expected/halfvec_functions.out | 104 +++ test/expected/halfvec_input.out | 147 ++++ test/expected/hnsw_halfvec_cosine.out | 26 + test/expected/hnsw_halfvec_ip.out | 21 + test/expected/hnsw_halfvec_l2.out | 33 + test/sql/cast.sql | 6 + test/sql/copy.sql | 6 +- test/sql/halfvec_functions.sql | 23 + test/sql/halfvec_input.sql | 34 + test/sql/hnsw_halfvec_cosine.sql | 13 + test/sql/hnsw_halfvec_ip.sql | 12 + test/sql/hnsw_halfvec_l2.sql | 16 + test/t/021_hnsw_halfvec_build_recall.pl | 132 ++++ 28 files changed, 1972 insertions(+), 21 deletions(-) create mode 100644 src/halfvec.c create mode 100644 src/halfvec.h create mode 100644 test/expected/halfvec_functions.out create mode 100644 test/expected/halfvec_input.out create mode 100644 test/expected/hnsw_halfvec_cosine.out create mode 100644 test/expected/hnsw_halfvec_ip.out create mode 100644 test/expected/hnsw_halfvec_l2.out create mode 100644 test/sql/halfvec_functions.sql create mode 100644 test/sql/halfvec_input.sql create mode 100644 test/sql/hnsw_halfvec_cosine.sql create mode 100644 test/sql/hnsw_halfvec_ip.sql create mode 100644 test/sql/hnsw_halfvec_l2.sql create mode 100644 test/t/021_hnsw_halfvec_build_recall.pl diff --git a/CHANGELOG.md b/CHANGELOG.md index 046f082..5bf1395 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -1,5 +1,6 @@ ## 0.7.0 (unreleased) +- Added `halfvec` type - Added support for bit vectors to HNSW - Added `hamming_distance` function - Added `jaccard_distance` function diff --git a/Makefile b/Makefile index 04758d0..cab9397 100644 --- a/Makefile +++ b/Makefile @@ -3,8 +3,8 @@ EXTVERSION = 0.6.2 MODULE_big = vector DATA = $(wildcard sql/*--*.sql) -OBJS = src/bitvector.o src/hnsw.o src/hnswbuild.o src/hnswinsert.o src/hnswscan.o src/hnswutils.o src/hnswvacuum.o src/ivfbuild.o src/ivfflat.o src/ivfinsert.o src/ivfkmeans.o src/ivfscan.o src/ivfutils.o src/ivfvacuum.o src/vector.o -HEADERS = src/vector.h +OBJS = src/bitvector.o src/halfvec.o src/hnsw.o src/hnswbuild.o src/hnswinsert.o src/hnswscan.o src/hnswutils.o src/hnswvacuum.o src/ivfbuild.o src/ivfflat.o src/ivfinsert.o src/ivfkmeans.o src/ivfscan.o src/ivfutils.o src/ivfvacuum.o src/vector.o +HEADERS = src/halfvec.h src/vector.h TESTS = $(wildcard test/sql/*.sql) REGRESS = $(patsubst test/sql/%.sql,%,$(TESTS)) diff --git a/Makefile.win b/Makefile.win index cb911b9..04ece60 100644 --- a/Makefile.win +++ b/Makefile.win @@ -1,10 +1,10 @@ EXTENSION = vector EXTVERSION = 0.6.2 -OBJS = src\bitvector.obj src\hnsw.obj src\hnswbuild.obj src\hnswinsert.obj src\hnswscan.obj src\hnswutils.obj src\hnswvacuum.obj src\ivfbuild.obj src\ivfflat.obj src\ivfinsert.obj src\ivfkmeans.obj src\ivfscan.obj src\ivfutils.obj src\ivfvacuum.obj src\vector.obj -HEADERS = src\vector.h +OBJS = src\bitvector.obj src\halfvec.obj src\hnsw.obj src\hnswbuild.obj src\hnswinsert.obj src\hnswscan.obj src\hnswutils.obj src\hnswvacuum.obj src\ivfbuild.obj src\ivfflat.obj src\ivfinsert.obj src\ivfkmeans.obj src\ivfscan.obj src\ivfutils.obj src\ivfvacuum.obj src\vector.obj +HEADERS = src\halfvec.h src\vector.h -REGRESS = bit_functions btree cast copy hnsw_bit_hamming hnsw_bit_jaccard hnsw_options hnsw_unlogged hnsw_vector_cosine hnsw_vector_ip hnsw_vector_l2 ivfflat_options ivfflat_unlogged ivfflat_vector_cosine ivfflat_vector_ip ivfflat_vector_l2 vector_functions vector_input +REGRESS = bit_functions btree cast copy halfvec_functions halfvec_input hnsw_bit_hamming hnsw_bit_jaccard hnsw_halfvec_cosine hnsw_halfvec_ip hnsw_halfvec_l2 hnsw_options hnsw_unlogged hnsw_vector_cosine hnsw_vector_ip hnsw_vector_l2 ivfflat_options ivfflat_unlogged ivfflat_vector_cosine ivfflat_vector_ip ivfflat_vector_l2 vector_functions vector_input REGRESS_OPTS = --inputdir=test --load-extension=$(EXTENSION) # For /arch flags diff --git a/README.md b/README.md index 140f635..cafe953 100644 --- a/README.md +++ b/README.md @@ -712,6 +712,7 @@ Also, note that `NULL` vectors are not indexed (as well as zero vectors for cosi ## Reference - [Vector](#vector-type) +- [Halfvec](#halfvec-type) - [Bit](#bit-type) ### Vector Type @@ -749,6 +750,27 @@ Function | Description | Added avg(vector) → vector | average | sum(vector) → vector | sum | 0.5.0 +### Halfvec Type + +Each half vector takes `2 * dimensions + 8` bytes of storage. Each element is a half-precision floating-point number, and all elements must be finite (no `NaN`, `Infinity` or `-Infinity`). Half vectors can have up to 16,000 dimensions. + +### Halfvec Operators + +Operator | Description | Added +--- | --- | --- +<-> | Euclidean distance | unreleased +<#> | negative inner product | unreleased +<=> | cosine distance | unreleased + +### Halfvec Functions + +Function | Description | Added +--- | --- | --- +cosine_distance(halfvec, halfvec) → double precision | cosine distance | unreleased +inner_product(halfvec, halfvec) → double precision | inner product | unreleased +l2_distance(halfvec, halfvec) → double precision | Euclidean distance | unreleased +l1_distance(halfvec, halfvec) → double precision | taxicab distance | unreleased + ### Bit Type Each bit vector takes `dimensions / 8 + (5 or 8)` bytes of storage. See the [Postgres docs](https://www.postgresql.org/docs/current/datatype-bit.html) for more info. diff --git a/sql/vector--0.6.2--0.7.0.sql b/sql/vector--0.6.2--0.7.0.sql index e3f7c20..dffd83c 100644 --- a/sql/vector--0.6.2--0.7.0.sql +++ b/sql/vector--0.6.2--0.7.0.sql @@ -32,3 +32,129 @@ CREATE OPERATOR CLASS bit_jaccard_ops FOR TYPE bit USING hnsw AS OPERATOR 1 <%> (bit, bit) FOR ORDER BY float_ops, FUNCTION 1 jaccard_distance(bit, bit); + +CREATE TYPE halfvec; + +CREATE FUNCTION halfvec_in(cstring, oid, integer) RETURNS halfvec + AS 'MODULE_PATHNAME' LANGUAGE C IMMUTABLE STRICT PARALLEL SAFE; + +CREATE FUNCTION halfvec_out(halfvec) RETURNS cstring + AS 'MODULE_PATHNAME' LANGUAGE C IMMUTABLE STRICT PARALLEL SAFE; + +CREATE FUNCTION halfvec_typmod_in(cstring[]) RETURNS integer + AS 'MODULE_PATHNAME' LANGUAGE C IMMUTABLE STRICT PARALLEL SAFE; + +CREATE FUNCTION halfvec_recv(internal, oid, integer) RETURNS halfvec + AS 'MODULE_PATHNAME' LANGUAGE C IMMUTABLE STRICT PARALLEL SAFE; + +CREATE FUNCTION halfvec_send(halfvec) RETURNS bytea + AS 'MODULE_PATHNAME' LANGUAGE C IMMUTABLE STRICT PARALLEL SAFE; + +CREATE TYPE halfvec ( + INPUT = halfvec_in, + OUTPUT = halfvec_out, + TYPMOD_IN = halfvec_typmod_in, + RECEIVE = halfvec_recv, + SEND = halfvec_send, + STORAGE = external +); + +CREATE FUNCTION l2_distance(halfvec, halfvec) RETURNS float8 + AS 'MODULE_PATHNAME', 'halfvec_l2_distance' LANGUAGE C IMMUTABLE STRICT PARALLEL SAFE; + +CREATE FUNCTION inner_product(halfvec, halfvec) RETURNS float8 + AS 'MODULE_PATHNAME', 'halfvec_inner_product' LANGUAGE C IMMUTABLE STRICT PARALLEL SAFE; + +CREATE FUNCTION cosine_distance(halfvec, halfvec) RETURNS float8 + AS 'MODULE_PATHNAME', 'halfvec_cosine_distance' LANGUAGE C IMMUTABLE STRICT PARALLEL SAFE; + +CREATE FUNCTION l1_distance(halfvec, halfvec) RETURNS float8 + AS 'MODULE_PATHNAME', 'halfvec_l1_distance' LANGUAGE C IMMUTABLE STRICT PARALLEL SAFE; + +CREATE FUNCTION halfvec_norm(halfvec) RETURNS float8 + AS 'MODULE_PATHNAME' LANGUAGE C IMMUTABLE STRICT PARALLEL SAFE; + +CREATE FUNCTION halfvec_l2_squared_distance(halfvec, halfvec) RETURNS float8 + AS 'MODULE_PATHNAME' LANGUAGE C IMMUTABLE STRICT PARALLEL SAFE; + +CREATE FUNCTION halfvec_negative_inner_product(halfvec, halfvec) RETURNS float8 + AS 'MODULE_PATHNAME' LANGUAGE C IMMUTABLE STRICT PARALLEL SAFE; + +CREATE FUNCTION halfvec(halfvec, integer, boolean) RETURNS halfvec + AS 'MODULE_PATHNAME' LANGUAGE C IMMUTABLE STRICT PARALLEL SAFE; + +CREATE FUNCTION array_to_halfvec(integer[], integer, boolean) RETURNS halfvec + AS 'MODULE_PATHNAME' LANGUAGE C IMMUTABLE STRICT PARALLEL SAFE; + +CREATE FUNCTION array_to_halfvec(real[], integer, boolean) RETURNS halfvec + AS 'MODULE_PATHNAME' LANGUAGE C IMMUTABLE STRICT PARALLEL SAFE; + +CREATE FUNCTION array_to_halfvec(double precision[], integer, boolean) RETURNS halfvec + AS 'MODULE_PATHNAME' LANGUAGE C IMMUTABLE STRICT PARALLEL SAFE; + +CREATE FUNCTION array_to_halfvec(numeric[], integer, boolean) RETURNS halfvec + AS 'MODULE_PATHNAME' LANGUAGE C IMMUTABLE STRICT PARALLEL SAFE; + +CREATE FUNCTION halfvec_to_float4(halfvec, integer, boolean) RETURNS real[] + AS 'MODULE_PATHNAME' LANGUAGE C IMMUTABLE STRICT PARALLEL SAFE; + +CREATE CAST (halfvec AS halfvec) + WITH FUNCTION halfvec(halfvec, integer, boolean) AS IMPLICIT; + +CREATE CAST (halfvec AS real[]) + WITH FUNCTION halfvec_to_float4(halfvec, integer, boolean) AS IMPLICIT; + +CREATE CAST (integer[] AS halfvec) + WITH FUNCTION array_to_halfvec(integer[], integer, boolean) AS ASSIGNMENT; + +CREATE CAST (real[] AS halfvec) + WITH FUNCTION array_to_halfvec(real[], integer, boolean) AS ASSIGNMENT; + +CREATE CAST (double precision[] AS halfvec) + WITH FUNCTION array_to_halfvec(double precision[], integer, boolean) AS ASSIGNMENT; + +CREATE CAST (numeric[] AS halfvec) + WITH FUNCTION array_to_halfvec(numeric[], integer, boolean) AS ASSIGNMENT; + +CREATE OPERATOR <-> ( + LEFTARG = halfvec, RIGHTARG = halfvec, PROCEDURE = l2_distance, + COMMUTATOR = '<->' +); + +CREATE OPERATOR <#> ( + LEFTARG = halfvec, RIGHTARG = halfvec, PROCEDURE = halfvec_negative_inner_product, + COMMUTATOR = '<#>' +); + +CREATE OPERATOR <=> ( + LEFTARG = halfvec, RIGHTARG = halfvec, PROCEDURE = cosine_distance, + COMMUTATOR = '<=>' +); + +CREATE OPERATOR CLASS halfvec_l2_ops + FOR TYPE halfvec USING hnsw AS + OPERATOR 1 <-> (halfvec, halfvec) FOR ORDER BY float_ops, + FUNCTION 1 halfvec_l2_squared_distance(halfvec, halfvec); + +CREATE OPERATOR CLASS halfvec_ip_ops + FOR TYPE halfvec USING hnsw AS + OPERATOR 1 <#> (halfvec, halfvec) FOR ORDER BY float_ops, + FUNCTION 1 halfvec_negative_inner_product(halfvec, halfvec); + +CREATE OPERATOR CLASS halfvec_cosine_ops + FOR TYPE halfvec USING hnsw AS + OPERATOR 1 <=> (halfvec, halfvec) FOR ORDER BY float_ops, + FUNCTION 1 halfvec_negative_inner_product(halfvec, halfvec), + FUNCTION 2 halfvec_norm(halfvec); + +CREATE FUNCTION halfvec_to_vector(halfvec, integer, boolean) RETURNS vector + AS 'MODULE_PATHNAME' LANGUAGE C IMMUTABLE STRICT PARALLEL SAFE; + +CREATE FUNCTION vector_to_halfvec(vector, integer, boolean) RETURNS halfvec + AS 'MODULE_PATHNAME' LANGUAGE C IMMUTABLE STRICT PARALLEL SAFE; + +CREATE CAST (halfvec AS vector) + WITH FUNCTION halfvec_to_vector(halfvec, integer, boolean) AS IMPLICIT; + +CREATE CAST (vector AS halfvec) + WITH FUNCTION vector_to_halfvec(vector, integer, boolean) AS IMPLICIT; diff --git a/sql/vector.sql b/sql/vector.sql index 543ed5a..3fc5081 100644 --- a/sql/vector.sql +++ b/sql/vector.sql @@ -321,3 +321,145 @@ CREATE OPERATOR CLASS bit_jaccard_ops FOR TYPE bit USING hnsw AS OPERATOR 1 <%> (bit, bit) FOR ORDER BY float_ops, FUNCTION 1 jaccard_distance(bit, bit); + +-- halfvec type + +CREATE TYPE halfvec; + +CREATE FUNCTION halfvec_in(cstring, oid, integer) RETURNS halfvec + AS 'MODULE_PATHNAME' LANGUAGE C IMMUTABLE STRICT PARALLEL SAFE; + +CREATE FUNCTION halfvec_out(halfvec) RETURNS cstring + AS 'MODULE_PATHNAME' LANGUAGE C IMMUTABLE STRICT PARALLEL SAFE; + +CREATE FUNCTION halfvec_typmod_in(cstring[]) RETURNS integer + AS 'MODULE_PATHNAME' LANGUAGE C IMMUTABLE STRICT PARALLEL SAFE; + +CREATE FUNCTION halfvec_recv(internal, oid, integer) RETURNS halfvec + AS 'MODULE_PATHNAME' LANGUAGE C IMMUTABLE STRICT PARALLEL SAFE; + +CREATE FUNCTION halfvec_send(halfvec) RETURNS bytea + AS 'MODULE_PATHNAME' LANGUAGE C IMMUTABLE STRICT PARALLEL SAFE; + +CREATE TYPE halfvec ( + INPUT = halfvec_in, + OUTPUT = halfvec_out, + TYPMOD_IN = halfvec_typmod_in, + RECEIVE = halfvec_recv, + SEND = halfvec_send, + STORAGE = external +); + +-- halfvec functions + +CREATE FUNCTION l2_distance(halfvec, halfvec) RETURNS float8 + AS 'MODULE_PATHNAME', 'halfvec_l2_distance' LANGUAGE C IMMUTABLE STRICT PARALLEL SAFE; + +CREATE FUNCTION inner_product(halfvec, halfvec) RETURNS float8 + AS 'MODULE_PATHNAME', 'halfvec_inner_product' LANGUAGE C IMMUTABLE STRICT PARALLEL SAFE; + +CREATE FUNCTION cosine_distance(halfvec, halfvec) RETURNS float8 + AS 'MODULE_PATHNAME', 'halfvec_cosine_distance' LANGUAGE C IMMUTABLE STRICT PARALLEL SAFE; + +CREATE FUNCTION l1_distance(halfvec, halfvec) RETURNS float8 + AS 'MODULE_PATHNAME', 'halfvec_l1_distance' LANGUAGE C IMMUTABLE STRICT PARALLEL SAFE; + +CREATE FUNCTION halfvec_norm(halfvec) RETURNS float8 + AS 'MODULE_PATHNAME' LANGUAGE C IMMUTABLE STRICT PARALLEL SAFE; + +-- halfvec private functions + +CREATE FUNCTION halfvec_l2_squared_distance(halfvec, halfvec) RETURNS float8 + AS 'MODULE_PATHNAME' LANGUAGE C IMMUTABLE STRICT PARALLEL SAFE; + +CREATE FUNCTION halfvec_negative_inner_product(halfvec, halfvec) RETURNS float8 + AS 'MODULE_PATHNAME' LANGUAGE C IMMUTABLE STRICT PARALLEL SAFE; + +-- halfvec cast functions + +CREATE FUNCTION halfvec(halfvec, integer, boolean) RETURNS halfvec + AS 'MODULE_PATHNAME' LANGUAGE C IMMUTABLE STRICT PARALLEL SAFE; + +CREATE FUNCTION array_to_halfvec(integer[], integer, boolean) RETURNS halfvec + AS 'MODULE_PATHNAME' LANGUAGE C IMMUTABLE STRICT PARALLEL SAFE; + +CREATE FUNCTION array_to_halfvec(real[], integer, boolean) RETURNS halfvec + AS 'MODULE_PATHNAME' LANGUAGE C IMMUTABLE STRICT PARALLEL SAFE; + +CREATE FUNCTION array_to_halfvec(double precision[], integer, boolean) RETURNS halfvec + AS 'MODULE_PATHNAME' LANGUAGE C IMMUTABLE STRICT PARALLEL SAFE; + +CREATE FUNCTION array_to_halfvec(numeric[], integer, boolean) RETURNS halfvec + AS 'MODULE_PATHNAME' LANGUAGE C IMMUTABLE STRICT PARALLEL SAFE; + +CREATE FUNCTION halfvec_to_float4(halfvec, integer, boolean) RETURNS real[] + AS 'MODULE_PATHNAME' LANGUAGE C IMMUTABLE STRICT PARALLEL SAFE; + +-- halfvec casts + +CREATE CAST (halfvec AS halfvec) + WITH FUNCTION halfvec(halfvec, integer, boolean) AS IMPLICIT; + +CREATE CAST (halfvec AS real[]) + WITH FUNCTION halfvec_to_float4(halfvec, integer, boolean) AS IMPLICIT; + +CREATE CAST (integer[] AS halfvec) + WITH FUNCTION array_to_halfvec(integer[], integer, boolean) AS ASSIGNMENT; + +CREATE CAST (real[] AS halfvec) + WITH FUNCTION array_to_halfvec(real[], integer, boolean) AS ASSIGNMENT; + +CREATE CAST (double precision[] AS halfvec) + WITH FUNCTION array_to_halfvec(double precision[], integer, boolean) AS ASSIGNMENT; + +CREATE CAST (numeric[] AS halfvec) + WITH FUNCTION array_to_halfvec(numeric[], integer, boolean) AS ASSIGNMENT; + +-- halfvec operators + +CREATE OPERATOR <-> ( + LEFTARG = halfvec, RIGHTARG = halfvec, PROCEDURE = l2_distance, + COMMUTATOR = '<->' +); + +CREATE OPERATOR <#> ( + LEFTARG = halfvec, RIGHTARG = halfvec, PROCEDURE = halfvec_negative_inner_product, + COMMUTATOR = '<#>' +); + +CREATE OPERATOR <=> ( + LEFTARG = halfvec, RIGHTARG = halfvec, PROCEDURE = cosine_distance, + COMMUTATOR = '<=>' +); + +-- halfvec opclasses + +CREATE OPERATOR CLASS halfvec_l2_ops + FOR TYPE halfvec USING hnsw AS + OPERATOR 1 <-> (halfvec, halfvec) FOR ORDER BY float_ops, + FUNCTION 1 halfvec_l2_squared_distance(halfvec, halfvec); + +CREATE OPERATOR CLASS halfvec_ip_ops + FOR TYPE halfvec USING hnsw AS + OPERATOR 1 <#> (halfvec, halfvec) FOR ORDER BY float_ops, + FUNCTION 1 halfvec_negative_inner_product(halfvec, halfvec); + +CREATE OPERATOR CLASS halfvec_cosine_ops + FOR TYPE halfvec USING hnsw AS + OPERATOR 1 <=> (halfvec, halfvec) FOR ORDER BY float_ops, + FUNCTION 1 halfvec_negative_inner_product(halfvec, halfvec), + FUNCTION 2 halfvec_norm(halfvec); + +-- extension casts + +CREATE FUNCTION halfvec_to_vector(halfvec, integer, boolean) RETURNS vector + AS 'MODULE_PATHNAME' LANGUAGE C IMMUTABLE STRICT PARALLEL SAFE; + +CREATE FUNCTION vector_to_halfvec(vector, integer, boolean) RETURNS halfvec + AS 'MODULE_PATHNAME' LANGUAGE C IMMUTABLE STRICT PARALLEL SAFE; + +CREATE CAST (halfvec AS vector) + WITH FUNCTION halfvec_to_vector(halfvec, integer, boolean) AS IMPLICIT; + +CREATE CAST (vector AS halfvec) + WITH FUNCTION vector_to_halfvec(vector, integer, boolean) AS IMPLICIT; diff --git a/src/halfvec.c b/src/halfvec.c new file mode 100644 index 0000000..36c1a85 --- /dev/null +++ b/src/halfvec.c @@ -0,0 +1,969 @@ +#include "postgres.h" + +#include + +#include "catalog/pg_type.h" +#include "common/shortest_dec.h" +#include "fmgr.h" +#include "halfvec.h" +#include "lib/stringinfo.h" +#include "libpq/pqformat.h" +#include "port.h" /* for strtof() */ +#include "utils/array.h" +#include "utils/builtins.h" +#include "utils/float.h" +#include "utils/lsyscache.h" +#include "utils/numeric.h" +#include "vector.h" + +#if PG_VERSION_NUM < 130000 +#define TYPALIGN_DOUBLE 'd' +#define TYPALIGN_INT 'i' +#endif + +/* + * Check if half is NaN + */ +static inline bool +HalfIsNan(half num) +{ +#ifdef FLT16_SUPPORT + return isnan(num); +#else + return (num & 0x7C00) == 0x7C00 && (num & 0x7FFF) != 0x7C00; +#endif +} + +/* + * Check if half is infinite + */ +static inline bool +HalfIsInf(half num) +{ +#ifdef FLT16_SUPPORT + return isinf(num); +#else + return (num & 0x7FFF) == 0x7C00; +#endif +} + +/* + * Check if half is zero + */ +static inline bool +HalfIsZero(half num) +{ +#ifdef FLT16_SUPPORT + return num == 0; +#else + return (num & 0x7FFF) == 0x0000; +#endif +} + +/* + * Get a half from a message buffer + */ +static half +pq_getmsghalf(StringInfo msg) +{ + union + { + half h; + uint16 i; + } swap; + + swap.i = pq_getmsgint(msg, 2); + return swap.h; +} + +/* + * Append a half to a StringInfo buffer + */ +static void +pq_sendhalf(StringInfo buf, half h) +{ + union + { + half h; + uint16 i; + } swap; + + swap.h = h; + pq_sendint16(buf, swap.i); +} + +/* + * Convert a half to a float4 + */ +float +HalfToFloat4(half num) +{ +#ifdef FLT16_SUPPORT + return (float) num; +#else + /* TODO Improve performance */ + + /* Assumes same endianness for floats and integers */ + union + { + float f; + uint32 i; + } swapfloat; + + union + { + half h; + uint16 i; + } swaphalf; + + uint16 bin; + uint32 exponent; + uint32 mantissa; + uint32 result; + + swaphalf.h = num; + bin = swaphalf.i; + exponent = (bin & 0x7C00) >> 10; + mantissa = bin & 0x03FF; + + /* Sign */ + result = (bin & 0x8000) << 16; + + if (exponent == 31) + { + if (mantissa == 0) + { + /* Infinite */ + result |= 0x7F800000; + } + else + { + /* NaN */ + result |= 0x7FC00000; + result |= mantissa << 13; + } + } + else if (exponent == 0) + { + /* Subnormal */ + if (mantissa != 0) + { + exponent = -14; + + for (int i = 0; i < 10; i++) + { + mantissa <<= 1; + exponent -= 1; + + if ((mantissa >> 10) % 2 == 1) + { + mantissa &= 0x03ff; + break; + } + } + + result |= (exponent + 127) << 23; + result |= mantissa << 13; + } + } + else + { + /* Normal */ + result |= (exponent - 15 + 127) << 23; + result |= mantissa << 13; + } + + swapfloat.i = result; + return swapfloat.f; +#endif +} + +/* + * Convert a float4 to a half + */ +half +Float4ToHalfUnchecked(float num) +{ +#ifdef FLT16_SUPPORT + return (_Float16) num; +#else + /* TODO Improve performance */ + + /* Assumes same endianness for floats and integers */ + union + { + float f; + uint32 i; + } swapfloat; + + union + { + half h; + uint16 i; + } swaphalf; + + uint32 bin; + int exponent; + int mantissa; + uint16 result; + + swapfloat.f = num; + bin = swapfloat.i; + exponent = (bin & 0x7F800000) >> 23; + mantissa = bin & 0x007FFFFF; + + /* Sign */ + result = (bin & 0x80000000) >> 16; + + if (isinf(num)) + { + /* Infinite */ + result |= 0x7C00; + } + else if (isnan(num)) + { + /* NaN */ + result |= 0x7E00; + result |= mantissa >> 13; + } + else if (exponent > 98) + { + int m; + int gr; + int s; + + exponent -= 127; + s = mantissa & 0x00000FFF; + + /* Subnormal */ + if (exponent < -14) + { + int diff = -exponent - 14; + + mantissa >>= diff; + mantissa += 1 << (23 - diff); + s |= mantissa & 0x00000FFF; + } + + m = mantissa >> 13; + + /* Round */ + gr = (mantissa >> 12) % 4; + if (gr == 3 || (gr == 1 && s != 0)) + m += 1; + + if (m == 1024) + { + m = 0; + exponent += 1; + } + + if (exponent > 15) + { + /* Infinite */ + result |= 0x7C00; + } + else + { + if (exponent >= -14) + result |= (exponent + 15) << 10; + + result |= m; + } + } + + swaphalf.i = result; + return swaphalf.h; +#endif +} + +/* + * Convert a float4 to a half + */ +half +Float4ToHalf(float num) +{ + half result = Float4ToHalfUnchecked(num); + + if (unlikely(HalfIsInf(result)) && !isinf(num)) + float_overflow_error(); + if (unlikely(HalfIsZero(result)) && num != 0.0) + float_underflow_error(); + + return result; +} + +/* + * Ensure same dimensions + */ +static inline void +CheckDims(HalfVector * a, HalfVector * b) +{ + if (a->dim != b->dim) + ereport(ERROR, + (errcode(ERRCODE_DATA_EXCEPTION), + errmsg("different halfvec dimensions %d and %d", a->dim, b->dim))); +} + +/* + * Ensure expected dimensions + */ +static inline void +CheckExpectedDim(int32 typmod, int dim) +{ + if (typmod != -1 && typmod != dim) + ereport(ERROR, + (errcode(ERRCODE_DATA_EXCEPTION), + errmsg("expected %d dimensions, not %d", typmod, dim))); +} + +/* + * Ensure valid dimensions + */ +static inline void +CheckDim(int dim) +{ + if (dim < 1) + ereport(ERROR, + (errcode(ERRCODE_DATA_EXCEPTION), + errmsg("halfvec must have at least 1 dimension"))); + + if (dim > HALFVEC_MAX_DIM) + ereport(ERROR, + (errcode(ERRCODE_PROGRAM_LIMIT_EXCEEDED), + errmsg("halfvec cannot have more than %d dimensions", HALFVEC_MAX_DIM))); +} + +/* + * Ensure finite element + */ +static inline void +CheckElement(half value) +{ + if (HalfIsNan(value)) + ereport(ERROR, + (errcode(ERRCODE_DATA_EXCEPTION), + errmsg("NaN not allowed in halfvec"))); + + if (HalfIsInf(value)) + ereport(ERROR, + (errcode(ERRCODE_DATA_EXCEPTION), + errmsg("infinite value not allowed in halfvec"))); +} + +/* + * Allocate and initialize a new half vector + */ +HalfVector * +InitHalfVector(int dim) +{ + HalfVector *result; + int size; + + size = HALFVEC_SIZE(dim); + result = (HalfVector *) palloc0(size); + SET_VARSIZE(result, size); + result->dim = dim; + + return result; +} + +/* + * Check for whitespace, since array_isspace() is static + */ +static inline bool +halfvec_isspace(char ch) +{ + if (ch == ' ' || + ch == '\t' || + ch == '\n' || + ch == '\r' || + ch == '\v' || + ch == '\f') + return true; + return false; +} + +#if PG_VERSION_NUM < 120003 +static pg_noinline void +float_overflow_error(void) +{ + ereport(ERROR, + (errcode(ERRCODE_NUMERIC_VALUE_OUT_OF_RANGE), + errmsg("value out of range: overflow"))); +} + +static pg_noinline void +float_underflow_error(void) +{ + ereport(ERROR, + (errcode(ERRCODE_NUMERIC_VALUE_OUT_OF_RANGE), + errmsg("value out of range: underflow"))); +} +#endif + +/* + * Convert textual representation to internal representation + */ +PGDLLEXPORT PG_FUNCTION_INFO_V1(halfvec_in); +Datum +halfvec_in(PG_FUNCTION_ARGS) +{ + char *lit = PG_GETARG_CSTRING(0); + int32 typmod = PG_GETARG_INT32(2); + half x[HALFVEC_MAX_DIM]; + int dim = 0; + char *pt; + char *stringEnd; + HalfVector *result; + char *litcopy = pstrdup(lit); + char *str = litcopy; + + while (halfvec_isspace(*str)) + str++; + + if (*str != '[') + ereport(ERROR, + (errcode(ERRCODE_INVALID_TEXT_REPRESENTATION), + errmsg("malformed halfvec literal: \"%s\"", lit), + errdetail("Vector contents must start with \"[\"."))); + + str++; + pt = strtok(str, ","); + stringEnd = pt; + + while (pt != NULL && *stringEnd != ']') + { + if (dim == HALFVEC_MAX_DIM) + ereport(ERROR, + (errcode(ERRCODE_PROGRAM_LIMIT_EXCEEDED), + errmsg("halfvec cannot have more than %d dimensions", HALFVEC_MAX_DIM))); + + while (halfvec_isspace(*pt)) + pt++; + + /* Check for empty string like float4in */ + if (*pt == '\0') + ereport(ERROR, + (errcode(ERRCODE_INVALID_TEXT_REPRESENTATION), + errmsg("invalid input syntax for type halfvec: \"%s\"", lit))); + + /* Use strtof like float4in to avoid a double-rounding problem */ + x[dim] = Float4ToHalf(strtof(pt, &stringEnd)); + CheckElement(x[dim]); + dim++; + + if (stringEnd == pt) + ereport(ERROR, + (errcode(ERRCODE_INVALID_TEXT_REPRESENTATION), + errmsg("invalid input syntax for type halfvec: \"%s\"", lit))); + + while (halfvec_isspace(*stringEnd)) + stringEnd++; + + if (*stringEnd != '\0' && *stringEnd != ']') + ereport(ERROR, + (errcode(ERRCODE_INVALID_TEXT_REPRESENTATION), + errmsg("invalid input syntax for type halfvec: \"%s\"", lit))); + + pt = strtok(NULL, ","); + } + + if (stringEnd == NULL || *stringEnd != ']') + ereport(ERROR, + (errcode(ERRCODE_INVALID_TEXT_REPRESENTATION), + errmsg("malformed halfvec literal: \"%s\"", lit), + errdetail("Unexpected end of input."))); + + stringEnd++; + + /* Only whitespace is allowed after the closing brace */ + while (halfvec_isspace(*stringEnd)) + stringEnd++; + + if (*stringEnd != '\0') + ereport(ERROR, + (errcode(ERRCODE_INVALID_TEXT_REPRESENTATION), + errmsg("malformed halfvec literal: \"%s\"", lit), + errdetail("Junk after closing right brace."))); + + /* Ensure no consecutive delimiters since strtok skips */ + for (pt = lit + 1; *pt != '\0'; pt++) + { + if (pt[-1] == ',' && *pt == ',') + ereport(ERROR, + (errcode(ERRCODE_INVALID_TEXT_REPRESENTATION), + errmsg("malformed halfvec literal: \"%s\"", lit))); + } + + if (dim < 1) + ereport(ERROR, + (errcode(ERRCODE_DATA_EXCEPTION), + errmsg("halfvec must have at least 1 dimension"))); + + pfree(litcopy); + + CheckExpectedDim(typmod, dim); + + result = InitHalfVector(dim); + for (int i = 0; i < dim; i++) + result->x[i] = x[i]; + + PG_RETURN_POINTER(result); +} + +#define AppendChar(ptr, c) (*(ptr)++ = (c)) +#define AppendFloat(ptr, f) ((ptr) += float_to_shortest_decimal_bufn((f), (ptr))) + +/* + * Convert internal representation to textual representation + */ +PGDLLEXPORT PG_FUNCTION_INFO_V1(halfvec_out); +Datum +halfvec_out(PG_FUNCTION_ARGS) +{ + HalfVector *vector = PG_GETARG_HALFVEC_P(0); + int dim = vector->dim; + char *buf; + char *ptr; + + /* + * Need: + * + * dim * (FLOAT_SHORTEST_DECIMAL_LEN - 1) bytes for + * float_to_shortest_decimal_bufn + * + * dim - 1 bytes for separator + * + * 3 bytes for [, ], and \0 + */ + buf = (char *) palloc(FLOAT_SHORTEST_DECIMAL_LEN * dim + 2); + ptr = buf; + + AppendChar(ptr, '['); + + for (int i = 0; i < dim; i++) + { + if (i > 0) + AppendChar(ptr, ','); + + AppendFloat(ptr, HalfToFloat4(vector->x[i])); + } + + AppendChar(ptr, ']'); + *ptr = '\0'; + + PG_FREE_IF_COPY(vector, 0); + PG_RETURN_CSTRING(buf); +} + +/* + * Convert type modifier + */ +PGDLLEXPORT PG_FUNCTION_INFO_V1(halfvec_typmod_in); +Datum +halfvec_typmod_in(PG_FUNCTION_ARGS) +{ + ArrayType *ta = PG_GETARG_ARRAYTYPE_P(0); + int32 *tl; + int n; + + tl = ArrayGetIntegerTypmods(ta, &n); + + if (n != 1) + ereport(ERROR, + (errcode(ERRCODE_INVALID_PARAMETER_VALUE), + errmsg("invalid type modifier"))); + + if (*tl < 1) + ereport(ERROR, + (errcode(ERRCODE_INVALID_PARAMETER_VALUE), + errmsg("dimensions for type halfvec must be at least 1"))); + + if (*tl > HALFVEC_MAX_DIM) + ereport(ERROR, + (errcode(ERRCODE_INVALID_PARAMETER_VALUE), + errmsg("dimensions for type halfvec cannot exceed %d", HALFVEC_MAX_DIM))); + + PG_RETURN_INT32(*tl); +} + +/* + * Convert external binary representation to internal representation + */ +PGDLLEXPORT PG_FUNCTION_INFO_V1(halfvec_recv); +Datum +halfvec_recv(PG_FUNCTION_ARGS) +{ + StringInfo buf = (StringInfo) PG_GETARG_POINTER(0); + int32 typmod = PG_GETARG_INT32(2); + HalfVector *result; + int16 dim; + int16 unused; + + dim = pq_getmsgint(buf, sizeof(int16)); + unused = pq_getmsgint(buf, sizeof(int16)); + + CheckDim(dim); + CheckExpectedDim(typmod, dim); + + if (unused != 0) + ereport(ERROR, + (errcode(ERRCODE_DATA_EXCEPTION), + errmsg("expected unused to be 0, not %d", unused))); + + result = InitHalfVector(dim); + for (int i = 0; i < dim; i++) + { + result->x[i] = pq_getmsghalf(buf); + CheckElement(result->x[i]); + } + + PG_RETURN_POINTER(result); +} + +/* + * Convert internal representation to the external binary representation + */ +PGDLLEXPORT PG_FUNCTION_INFO_V1(halfvec_send); +Datum +halfvec_send(PG_FUNCTION_ARGS) +{ + HalfVector *vec = PG_GETARG_HALFVEC_P(0); + StringInfoData buf; + + pq_begintypsend(&buf); + pq_sendint(&buf, vec->dim, sizeof(int16)); + pq_sendint(&buf, vec->unused, sizeof(int16)); + for (int i = 0; i < vec->dim; i++) + pq_sendhalf(&buf, vec->x[i]); + + PG_RETURN_BYTEA_P(pq_endtypsend(&buf)); +} + +/* + * Convert half vector to half vector + * This is needed to check the type modifier + */ +PGDLLEXPORT PG_FUNCTION_INFO_V1(halfvec); +Datum +halfvec(PG_FUNCTION_ARGS) +{ + HalfVector *vec = PG_GETARG_HALFVEC_P(0); + int32 typmod = PG_GETARG_INT32(1); + + CheckExpectedDim(typmod, vec->dim); + + PG_RETURN_POINTER(vec); +} + +/* + * Convert array to half vector + */ +PGDLLEXPORT PG_FUNCTION_INFO_V1(array_to_halfvec); +Datum +array_to_halfvec(PG_FUNCTION_ARGS) +{ + ArrayType *array = PG_GETARG_ARRAYTYPE_P(0); + int32 typmod = PG_GETARG_INT32(1); + HalfVector *result; + int16 typlen; + bool typbyval; + char typalign; + Datum *elemsp; + int nelemsp; + + if (ARR_NDIM(array) > 1) + ereport(ERROR, + (errcode(ERRCODE_DATA_EXCEPTION), + errmsg("array must be 1-D"))); + + if (ARR_HASNULL(array) && array_contains_nulls(array)) + ereport(ERROR, + (errcode(ERRCODE_NULL_VALUE_NOT_ALLOWED), + errmsg("array must not contain nulls"))); + + get_typlenbyvalalign(ARR_ELEMTYPE(array), &typlen, &typbyval, &typalign); + deconstruct_array(array, ARR_ELEMTYPE(array), typlen, typbyval, typalign, &elemsp, NULL, &nelemsp); + + CheckDim(nelemsp); + CheckExpectedDim(typmod, nelemsp); + + result = InitHalfVector(nelemsp); + + if (ARR_ELEMTYPE(array) == INT4OID) + { + for (int i = 0; i < nelemsp; i++) + result->x[i] = Float4ToHalf(DatumGetInt32(elemsp[i])); + } + else if (ARR_ELEMTYPE(array) == FLOAT8OID) + { + for (int i = 0; i < nelemsp; i++) + result->x[i] = Float4ToHalf(DatumGetFloat8(elemsp[i])); + } + else if (ARR_ELEMTYPE(array) == FLOAT4OID) + { + for (int i = 0; i < nelemsp; i++) + result->x[i] = Float4ToHalf(DatumGetFloat4(elemsp[i])); + } + else if (ARR_ELEMTYPE(array) == NUMERICOID) + { + for (int i = 0; i < nelemsp; i++) + result->x[i] = Float4ToHalf(DatumGetFloat4(DirectFunctionCall1(numeric_float4, elemsp[i]))); + } + else + { + ereport(ERROR, + (errcode(ERRCODE_DATA_EXCEPTION), + errmsg("unsupported array type"))); + } + + /* + * Free allocation from deconstruct_array. Do not free individual elements + * when pass-by-reference since they point to original array. + */ + pfree(elemsp); + + /* Check elements */ + for (int i = 0; i < result->dim; i++) + CheckElement(result->x[i]); + + PG_RETURN_POINTER(result); +} + +/* + * Convert half vector to float4[] + */ +PGDLLEXPORT PG_FUNCTION_INFO_V1(halfvec_to_float4); +Datum +halfvec_to_float4(PG_FUNCTION_ARGS) +{ + HalfVector *vec = PG_GETARG_HALFVEC_P(0); + Datum *datums; + ArrayType *result; + + datums = (Datum *) palloc(sizeof(Datum) * vec->dim); + + for (int i = 0; i < vec->dim; i++) + datums[i] = Float4GetDatum(HalfToFloat4(vec->x[i])); + + /* Use TYPALIGN_INT for float4 */ + result = construct_array(datums, vec->dim, FLOAT4OID, sizeof(float4), true, TYPALIGN_INT); + + pfree(datums); + + PG_RETURN_POINTER(result); +} + +/* + * Convert vector to half vec + */ +PGDLLEXPORT PG_FUNCTION_INFO_V1(vector_to_halfvec); +Datum +vector_to_halfvec(PG_FUNCTION_ARGS) +{ + Vector *vec = PG_GETARG_VECTOR_P(0); + int32 typmod = PG_GETARG_INT32(1); + HalfVector *result; + + CheckDim(vec->dim); + CheckExpectedDim(typmod, vec->dim); + + result = InitHalfVector(vec->dim); + + for (int i = 0; i < vec->dim; i++) + { + result->x[i] = Float4ToHalfUnchecked(vec->x[i]); + /* TODO Better error for overflow */ + CheckElement(result->x[i]); + } + + PG_RETURN_POINTER(result); +} + +/* + * Get the L2 distance between half vectors + */ +PGDLLEXPORT PG_FUNCTION_INFO_V1(halfvec_l2_distance); +Datum +halfvec_l2_distance(PG_FUNCTION_ARGS) +{ + HalfVector *a = PG_GETARG_HALFVEC_P(0); + HalfVector *b = PG_GETARG_HALFVEC_P(1); + half *ax = a->x; + half *bx = b->x; + float distance = 0.0; + + CheckDims(a, b); + + /* Auto-vectorized */ + for (int i = 0; i < a->dim; i++) + { + float diff = HalfToFloat4(ax[i]) - HalfToFloat4(bx[i]); + + distance += diff * diff; + } + + PG_RETURN_FLOAT8(sqrt((double) distance)); +} + +/* + * Get the L2 squared distance between half vectors + */ +PGDLLEXPORT PG_FUNCTION_INFO_V1(halfvec_l2_squared_distance); +Datum +halfvec_l2_squared_distance(PG_FUNCTION_ARGS) +{ + HalfVector *a = PG_GETARG_HALFVEC_P(0); + HalfVector *b = PG_GETARG_HALFVEC_P(1); + half *ax = a->x; + half *bx = b->x; + float distance = 0.0; + + CheckDims(a, b); + + /* Auto-vectorized */ + for (int i = 0; i < a->dim; i++) + { + float diff = HalfToFloat4(ax[i]) - HalfToFloat4(bx[i]); + + distance += diff * diff; + } + + PG_RETURN_FLOAT8((double) distance); +} + +/* + * Get the inner product of two half vectors + */ +PGDLLEXPORT PG_FUNCTION_INFO_V1(halfvec_inner_product); +Datum +halfvec_inner_product(PG_FUNCTION_ARGS) +{ + HalfVector *a = PG_GETARG_HALFVEC_P(0); + HalfVector *b = PG_GETARG_HALFVEC_P(1); + half *ax = a->x; + half *bx = b->x; + float distance = 0.0; + + CheckDims(a, b); + + /* Auto-vectorized */ + for (int i = 0; i < a->dim; i++) + distance += HalfToFloat4(ax[i]) * HalfToFloat4(bx[i]); + + PG_RETURN_FLOAT8((double) distance); +} + +/* + * Get the negative inner product of two half vectors + */ +PGDLLEXPORT PG_FUNCTION_INFO_V1(halfvec_negative_inner_product); +Datum +halfvec_negative_inner_product(PG_FUNCTION_ARGS) +{ + HalfVector *a = PG_GETARG_HALFVEC_P(0); + HalfVector *b = PG_GETARG_HALFVEC_P(1); + half *ax = a->x; + half *bx = b->x; + float distance = 0.0; + + CheckDims(a, b); + + /* Auto-vectorized */ + for (int i = 0; i < a->dim; i++) + distance += HalfToFloat4(ax[i]) * HalfToFloat4(bx[i]); + + PG_RETURN_FLOAT8((double) distance * -1); +} + +/* + * Get the cosine distance between two half vectors + */ +PGDLLEXPORT PG_FUNCTION_INFO_V1(halfvec_cosine_distance); +Datum +halfvec_cosine_distance(PG_FUNCTION_ARGS) +{ + HalfVector *a = PG_GETARG_HALFVEC_P(0); + HalfVector *b = PG_GETARG_HALFVEC_P(1); + half *ax = a->x; + half *bx = b->x; + float distance = 0.0; + float norma = 0.0; + float normb = 0.0; + double similarity; + + CheckDims(a, b); + + /* Auto-vectorized */ + for (int i = 0; i < a->dim; i++) + { + float axi = HalfToFloat4(ax[i]); + float bxi = HalfToFloat4(bx[i]); + + distance += axi * bxi; + norma += axi * axi; + normb += bxi * bxi; + } + + /* Use sqrt(a * b) over sqrt(a) * sqrt(b) */ + similarity = (double) distance / sqrt((double) norma * (double) normb); + +#ifdef _MSC_VER + /* /fp:fast may not propagate NaN */ + if (isnan(similarity)) + PG_RETURN_FLOAT8(NAN); +#endif + + /* Keep in range */ + if (similarity > 1) + similarity = 1; + else if (similarity < -1) + similarity = -1; + + PG_RETURN_FLOAT8(1 - similarity); +} + +/* + * Get the L1 distance between two half vectors + */ +PGDLLEXPORT PG_FUNCTION_INFO_V1(halfvec_l1_distance); +Datum +halfvec_l1_distance(PG_FUNCTION_ARGS) +{ + HalfVector *a = PG_GETARG_HALFVEC_P(0); + HalfVector *b = PG_GETARG_HALFVEC_P(1); + half *ax = a->x; + half *bx = b->x; + float distance = 0.0; + + CheckDims(a, b); + + /* Auto-vectorized */ + for (int i = 0; i < a->dim; i++) + distance += fabsf(HalfToFloat4(ax[i]) - HalfToFloat4(bx[i])); + + PG_RETURN_FLOAT8((double) distance); +} + +/* + * Get the L2 norm of a half vector + */ +PGDLLEXPORT PG_FUNCTION_INFO_V1(halfvec_norm); +Datum +halfvec_norm(PG_FUNCTION_ARGS) +{ + HalfVector *a = PG_GETARG_HALFVEC_P(0); + half *ax = a->x; + double norm = 0.0; + + /* Auto-vectorized */ + for (int i = 0; i < a->dim; i++) + { + double axi = (double) HalfToFloat4(ax[i]); + + norm += axi * axi; + } + + PG_RETURN_FLOAT8(sqrt(norm)); +} diff --git a/src/halfvec.h b/src/halfvec.h new file mode 100644 index 0000000..4c8adfd --- /dev/null +++ b/src/halfvec.h @@ -0,0 +1,43 @@ +#ifndef HALFVEC_H +#define HALFVEC_H + +#define __STDC_WANT_IEC_60559_TYPES_EXT__ + +#include + +#include "vector.h" + +#ifdef __FLT16_MAX__ +#define FLT16_SUPPORT +#endif + +#ifdef FLT16_SUPPORT +#define half _Float16 +#define HALF_MAX FLT16_MAX +#else +/* TODO #pragma message("")? */ +#define half uint16 +#define HALF_MAX 65504 +#endif + +#define HALFVEC_MAX_DIM VECTOR_MAX_DIM + +#define HALFVEC_SIZE(_dim) (offsetof(HalfVector, x) + sizeof(half)*(_dim)) +#define DatumGetHalfVector(x) ((HalfVector *) PG_DETOAST_DATUM(x)) +#define PG_GETARG_HALFVEC_P(x) DatumGetHalfVector(PG_GETARG_DATUM(x)) +#define PG_RETURN_HALFVEC_P(x) PG_RETURN_POINTER(x) + +typedef struct HalfVector +{ + int32 vl_len_; /* varlena header (do not touch directly!) */ + int16 dim; /* number of dimensions */ + int16 unused; + half x[FLEXIBLE_ARRAY_MEMBER]; +} HalfVector; + +HalfVector *InitHalfVector(int dim); +float HalfToFloat4(half num); +half Float4ToHalf(float num); +half Float4ToHalfUnchecked(float num); + +#endif diff --git a/src/hnsw.h b/src/hnsw.h index 2e9adbf..3012f5f 100644 --- a/src/hnsw.h +++ b/src/hnsw.h @@ -58,6 +58,7 @@ typedef enum HnswType { HNSW_TYPE_VECTOR, + HNSW_TYPE_HALFVEC, HNSW_TYPE_BIT } HnswType; diff --git a/src/hnswbuild.c b/src/hnswbuild.c index 971d4ba..5e586f6 100644 --- a/src/hnswbuild.c +++ b/src/hnswbuild.c @@ -674,7 +674,9 @@ GetMaxDimensions(HnswType type) { int maxDimensions = HNSW_MAX_DIM; - if (type == HNSW_TYPE_BIT) + if (type == HNSW_TYPE_HALFVEC) + maxDimensions *= 2; + else if (type == HNSW_TYPE_BIT) maxDimensions *= 32; return maxDimensions; diff --git a/src/hnswscan.c b/src/hnswscan.c index d2016db..8dd4efd 100644 --- a/src/hnswscan.c +++ b/src/hnswscan.c @@ -1,8 +1,6 @@ #include "postgres.h" #include "access/relscan.h" -#include "bitvector.h" -#include "catalog/pg_type_d.h" #include "hnsw.h" #include "pgstat.h" #include "storage/bufmgr.h" diff --git a/src/hnswutils.c b/src/hnswutils.c index 9bbd398..272934c 100644 --- a/src/hnswutils.c +++ b/src/hnswutils.c @@ -3,13 +3,16 @@ #include #include "access/generic_xlog.h" +#include "catalog/pg_type.h" #include "catalog/pg_type_d.h" +#include "halfvec.h" #include "hnsw.h" #include "lib/pairingheap.h" #include "storage/bufmgr.h" #include "utils/datum.h" #include "utils/memdebug.h" #include "utils/rel.h" +#include "utils/syscache.h" #include "vector.h" #if PG_VERSION_NUM >= 130000 @@ -157,11 +160,28 @@ HnswType HnswGetType(Relation index) { Oid typid = TupleDescAttr(index->rd_att, 0)->atttypid; + HeapTuple tuple; + Form_pg_type type; + int result; if (typid == BITOID || typid == VARBITOID) return HNSW_TYPE_BIT; - return HNSW_TYPE_VECTOR; + tuple = SearchSysCache1(TYPEOID, ObjectIdGetDatum(typid)); + if (!HeapTupleIsValid(tuple)) + elog(ERROR, "cache lookup failed for type %u", typid); + + type = (Form_pg_type) GETSTRUCT(tuple); + if (strcmp(NameStr(type->typname), "vector") == 0) + result = HNSW_TYPE_VECTOR; + else if (strcmp(NameStr(type->typname), "halfvec") == 0) + result = HNSW_TYPE_HALFVEC; + else + elog(ERROR, "Unsupported type"); + + ReleaseSysCache(tuple); + + return result; } /* @@ -190,6 +210,19 @@ HnswNormValue(FmgrInfo *procinfo, Oid collation, Datum *value, HnswType type) *value = PointerGetDatum(result); } + else if (type == HNSW_TYPE_HALFVEC) + { + HalfVector *v = DatumGetHalfVector(*value); + HalfVector *result = InitHalfVector(v->dim); + + for (int i = 0; i < v->dim; i++) + { + /* TODO Fix */ + result->x[i] = Float4ToHalfUnchecked(HalfToFloat4(v->x[i]) / norm); + } + + *value = PointerGetDatum(result); + } else elog(ERROR, "Unsupported type"); diff --git a/src/vector.c b/src/vector.c index e678f51..97d922f 100644 --- a/src/vector.c +++ b/src/vector.c @@ -6,6 +6,7 @@ #include "catalog/pg_type.h" #include "common/shortest_dec.h" #include "fmgr.h" +#include "halfvec.h" #include "hnsw.h" #include "ivfflat.h" #include "lib/stringinfo.h" @@ -531,6 +532,28 @@ vector_to_float4(PG_FUNCTION_ARGS) PG_RETURN_POINTER(result); } +/* + * Convert half vector to vector + */ +PGDLLEXPORT PG_FUNCTION_INFO_V1(halfvec_to_vector); +Datum +halfvec_to_vector(PG_FUNCTION_ARGS) +{ + HalfVector *vec = PG_GETARG_HALFVEC_P(0); + int32 typmod = PG_GETARG_INT32(1); + Vector *result; + + CheckDim(vec->dim); + CheckExpectedDim(typmod, vec->dim); + + result = InitVector(vec->dim); + + for (int i = 0; i < vec->dim; i++) + result->x[i] = HalfToFloat4(vec->x[i]); + + PG_RETURN_POINTER(result); +} + /* * Get the L2 distance between vectors */ diff --git a/test/expected/cast.out b/test/expected/cast.out index 4824261..a5772d4 100644 --- a/test/expected/cast.out +++ b/test/expected/cast.out @@ -46,6 +46,30 @@ SELECT '[1,2,3]'::vector::real[]; {1,2,3} (1 row) +SELECT '[1,2,3]'::vector::halfvec; + halfvec +--------- + [1,2,3] +(1 row) + +SELECT '[1,2,3]'::halfvec::vector; + vector +--------- + [1,2,3] +(1 row) + +SELECT '[1,2,3]'::vector::halfvec(2); +ERROR: expected 2 dimensions, not 3 +SELECT '[1,2,3]'::halfvec::vector(2); +ERROR: expected 2 dimensions, not 3 +SELECT '[65520]'::vector::halfvec; +ERROR: infinite value not allowed in halfvec +SELECT '[1e-8]'::vector::halfvec; + halfvec +--------- + [0] +(1 row) + SELECT array_agg(n)::vector FROM generate_series(1, 16001) n; ERROR: vector cannot have more than 16000 dimensions SELECT array_to_vector(array_agg(n), 16001, false) FROM generate_series(1, 16001) n; diff --git a/test/expected/copy.out b/test/expected/copy.out index 36d4620..79657cc 100644 --- a/test/expected/copy.out +++ b/test/expected/copy.out @@ -1,15 +1,15 @@ -CREATE TABLE t (val vector(3)); -INSERT INTO t (val) VALUES ('[0,0,0]'), ('[1,2,3]'), ('[1,1,1]'), (NULL); -CREATE TABLE t2 (val vector(3)); +CREATE TABLE t (val vector(3), val2 halfvec(3)); +INSERT INTO t (val, val2) VALUES ('[0,0,0]', '[0,0,0]'), ('[1,2,3]', '[1,2,3]'), ('[1,1,1]', '[1,1,1]'), (NULL, NULL); +CREATE TABLE t2 (val vector(3), val2 halfvec(3)); \copy t TO 'results/data.bin' WITH (FORMAT binary) \copy t2 FROM 'results/data.bin' WITH (FORMAT binary) SELECT * FROM t2 ORDER BY val; - val ---------- - [0,0,0] - [1,1,1] - [1,2,3] - + val | val2 +---------+--------- + [0,0,0] | [0,0,0] + [1,1,1] | [1,1,1] + [1,2,3] | [1,2,3] + | (4 rows) DROP TABLE t; diff --git a/test/expected/halfvec_functions.out b/test/expected/halfvec_functions.out new file mode 100644 index 0000000..b0bd832 --- /dev/null +++ b/test/expected/halfvec_functions.out @@ -0,0 +1,104 @@ +SELECT l2_distance('[0,0]'::halfvec, '[3,4]'); + l2_distance +------------- + 5 +(1 row) + +SELECT l2_distance('[0,0]'::halfvec, '[0,1]'); + l2_distance +------------- + 1 +(1 row) + +SELECT l2_distance('[1,2]'::halfvec, '[3]'); +ERROR: different halfvec dimensions 2 and 1 +SELECT '[0,0]'::halfvec <-> '[3,4]'; + ?column? +---------- + 5 +(1 row) + +SELECT inner_product('[1,2]'::halfvec, '[3,4]'); + inner_product +--------------- + 11 +(1 row) + +SELECT inner_product('[1,2]'::halfvec, '[3]'); +ERROR: different halfvec dimensions 2 and 1 +SELECT inner_product('[65504]'::halfvec, '[65504]'); + inner_product +--------------- + 4290774016 +(1 row) + +SELECT '[1,2]'::halfvec <#> '[3,4]'; + ?column? +---------- + -11 +(1 row) + +SELECT cosine_distance('[1,2]'::halfvec, '[2,4]'); + cosine_distance +----------------- + 0 +(1 row) + +SELECT cosine_distance('[1,2]'::halfvec, '[0,0]'); + cosine_distance +----------------- + NaN +(1 row) + +SELECT cosine_distance('[1,1]'::halfvec, '[1,1]'); + cosine_distance +----------------- + 0 +(1 row) + +SELECT cosine_distance('[1,0]'::halfvec, '[0,2]'); + cosine_distance +----------------- + 1 +(1 row) + +SELECT cosine_distance('[1,1]'::halfvec, '[-1,-1]'); + cosine_distance +----------------- + 2 +(1 row) + +SELECT cosine_distance('[1,2]'::halfvec, '[3]'); +ERROR: different halfvec dimensions 2 and 1 +SELECT cosine_distance('[1,1]'::halfvec, '[1.1,1.1]'); + cosine_distance +----------------- + 0 +(1 row) + +SELECT cosine_distance('[1,1]'::halfvec, '[-1.1,-1.1]'); + cosine_distance +----------------- + 2 +(1 row) + +SELECT '[1,2]'::halfvec <=> '[2,4]'; + ?column? +---------- + 0 +(1 row) + +SELECT l1_distance('[0,0]'::halfvec, '[3,4]'); + l1_distance +------------- + 7 +(1 row) + +SELECT l1_distance('[0,0]'::halfvec, '[0,1]'); + l1_distance +------------- + 1 +(1 row) + +SELECT l1_distance('[1,2]'::halfvec, '[3]'); +ERROR: different halfvec dimensions 2 and 1 diff --git a/test/expected/halfvec_input.out b/test/expected/halfvec_input.out new file mode 100644 index 0000000..169bd6c --- /dev/null +++ b/test/expected/halfvec_input.out @@ -0,0 +1,147 @@ +SELECT '[1,2,3]'::halfvec; + halfvec +--------- + [1,2,3] +(1 row) + +SELECT '[-1,-2,-3]'::halfvec; + halfvec +------------ + [-1,-2,-3] +(1 row) + +SELECT '[1.,2.,3.]'::halfvec; + halfvec +--------- + [1,2,3] +(1 row) + +SELECT ' [ 1, 2 , 3 ] '::halfvec; + halfvec +--------- + [1,2,3] +(1 row) + +SELECT '[1.23456]'::halfvec; + halfvec +------------ + [1.234375] +(1 row) + +SELECT '[hello,1]'::halfvec; +ERROR: invalid input syntax for type halfvec: "[hello,1]" +LINE 1: SELECT '[hello,1]'::halfvec; + ^ +SELECT '[NaN,1]'::halfvec; +ERROR: NaN not allowed in halfvec +LINE 1: SELECT '[NaN,1]'::halfvec; + ^ +SELECT '[Infinity,1]'::halfvec; +ERROR: infinite value not allowed in halfvec +LINE 1: SELECT '[Infinity,1]'::halfvec; + ^ +SELECT '[-Infinity,1]'::halfvec; +ERROR: infinite value not allowed in halfvec +LINE 1: SELECT '[-Infinity,1]'::halfvec; + ^ +SELECT '[65519,-65519]'::halfvec; + halfvec +---------------- + [65504,-65504] +(1 row) + +SELECT '[65520,-65520]'::halfvec; +ERROR: value out of range: overflow +LINE 1: SELECT '[65520,-65520]'::halfvec; + ^ +SELECT '[1e-8,-1e-8]'::halfvec; +ERROR: value out of range: underflow +LINE 1: SELECT '[1e-8,-1e-8]'::halfvec; + ^ +SELECT '[4e38,1]'::halfvec; +ERROR: infinite value not allowed in halfvec +LINE 1: SELECT '[4e38,1]'::halfvec; + ^ +SELECT '[1,2,3'::halfvec; +ERROR: malformed halfvec literal: "[1,2,3" +LINE 1: SELECT '[1,2,3'::halfvec; + ^ +DETAIL: Unexpected end of input. +SELECT '[1,2,3]9'::halfvec; +ERROR: malformed halfvec literal: "[1,2,3]9" +LINE 1: SELECT '[1,2,3]9'::halfvec; + ^ +DETAIL: Junk after closing right brace. +SELECT '1,2,3'::halfvec; +ERROR: malformed halfvec literal: "1,2,3" +LINE 1: SELECT '1,2,3'::halfvec; + ^ +DETAIL: Vector contents must start with "[". +SELECT ''::halfvec; +ERROR: malformed halfvec literal: "" +LINE 1: SELECT ''::halfvec; + ^ +DETAIL: Vector contents must start with "[". +SELECT '['::halfvec; +ERROR: malformed halfvec literal: "[" +LINE 1: SELECT '['::halfvec; + ^ +DETAIL: Unexpected end of input. +SELECT '[,'::halfvec; +ERROR: malformed halfvec literal: "[," +LINE 1: SELECT '[,'::halfvec; + ^ +DETAIL: Unexpected end of input. +SELECT '[]'::halfvec; +ERROR: halfvec must have at least 1 dimension +LINE 1: SELECT '[]'::halfvec; + ^ +SELECT '[1,]'::halfvec; +ERROR: invalid input syntax for type halfvec: "[1,]" +LINE 1: SELECT '[1,]'::halfvec; + ^ +SELECT '[1a]'::halfvec; +ERROR: invalid input syntax for type halfvec: "[1a]" +LINE 1: SELECT '[1a]'::halfvec; + ^ +SELECT '[1,,3]'::halfvec; +ERROR: malformed halfvec literal: "[1,,3]" +LINE 1: SELECT '[1,,3]'::halfvec; + ^ +SELECT '[1, ,3]'::halfvec; +ERROR: invalid input syntax for type halfvec: "[1, ,3]" +LINE 1: SELECT '[1, ,3]'::halfvec; + ^ +SELECT '[1,2,3]'::halfvec(3); + halfvec +--------- + [1,2,3] +(1 row) + +SELECT '[1,2,3]'::halfvec(2); +ERROR: expected 2 dimensions, not 3 +SELECT '[1,2,3]'::halfvec(3, 2); +ERROR: invalid type modifier +LINE 1: SELECT '[1,2,3]'::halfvec(3, 2); + ^ +SELECT '[1,2,3]'::halfvec('a'); +ERROR: invalid input syntax for type integer: "a" +LINE 1: SELECT '[1,2,3]'::halfvec('a'); + ^ +SELECT '[1,2,3]'::halfvec(0); +ERROR: dimensions for type halfvec must be at least 1 +LINE 1: SELECT '[1,2,3]'::halfvec(0); + ^ +SELECT '[1,2,3]'::halfvec(16001); +ERROR: dimensions for type halfvec cannot exceed 16000 +LINE 1: SELECT '[1,2,3]'::halfvec(16001); + ^ +SELECT unnest('{"[1,2,3]", "[4,5,6]"}'::halfvec[]); + unnest +--------- + [1,2,3] + [4,5,6] +(2 rows) + +SELECT '{"[1,2,3]"}'::halfvec(2)[]; +ERROR: expected 2 dimensions, not 3 diff --git a/test/expected/hnsw_halfvec_cosine.out b/test/expected/hnsw_halfvec_cosine.out new file mode 100644 index 0000000..6ccca49 --- /dev/null +++ b/test/expected/hnsw_halfvec_cosine.out @@ -0,0 +1,26 @@ +SET enable_seqscan = off; +CREATE TABLE t (val halfvec(3)); +INSERT INTO t (val) VALUES ('[0,0,0]'), ('[1,2,3]'), ('[1,1,1]'), (NULL); +CREATE INDEX ON t USING hnsw (val halfvec_cosine_ops); +INSERT INTO t (val) VALUES ('[1,2,4]'); +SELECT * FROM t ORDER BY val <=> '[3,3,3]'; + val +--------- + [1,1,1] + [1,2,3] + [1,2,4] +(3 rows) + +SELECT COUNT(*) FROM (SELECT * FROM t ORDER BY val <=> '[0,0,0]') t2; + count +------- + 3 +(1 row) + +SELECT COUNT(*) FROM (SELECT * FROM t ORDER BY val <=> (SELECT NULL::halfvec)) t2; + count +------- + 3 +(1 row) + +DROP TABLE t; diff --git a/test/expected/hnsw_halfvec_ip.out b/test/expected/hnsw_halfvec_ip.out new file mode 100644 index 0000000..7c004c9 --- /dev/null +++ b/test/expected/hnsw_halfvec_ip.out @@ -0,0 +1,21 @@ +SET enable_seqscan = off; +CREATE TABLE t (val halfvec(3)); +INSERT INTO t (val) VALUES ('[0,0,0]'), ('[1,2,3]'), ('[1,1,1]'), (NULL); +CREATE INDEX ON t USING hnsw (val halfvec_ip_ops); +INSERT INTO t (val) VALUES ('[1,2,4]'); +SELECT * FROM t ORDER BY val <#> '[3,3,3]'; + val +--------- + [1,2,4] + [1,2,3] + [1,1,1] + [0,0,0] +(4 rows) + +SELECT COUNT(*) FROM (SELECT * FROM t ORDER BY val <#> (SELECT NULL::halfvec)) t2; + count +------- + 4 +(1 row) + +DROP TABLE t; diff --git a/test/expected/hnsw_halfvec_l2.out b/test/expected/hnsw_halfvec_l2.out new file mode 100644 index 0000000..ecfe77c --- /dev/null +++ b/test/expected/hnsw_halfvec_l2.out @@ -0,0 +1,33 @@ +SET enable_seqscan = off; +CREATE TABLE t (val halfvec(3)); +INSERT INTO t (val) VALUES ('[0,0,0]'), ('[1,2,3]'), ('[1,1,1]'), (NULL); +CREATE INDEX ON t USING hnsw (val halfvec_l2_ops); +INSERT INTO t (val) VALUES ('[1,2,4]'); +SELECT * FROM t ORDER BY val <-> '[3,3,3]'; + val +--------- + [1,2,3] + [1,2,4] + [1,1,1] + [0,0,0] +(4 rows) + +SELECT COUNT(*) FROM (SELECT * FROM t ORDER BY val <-> (SELECT NULL::halfvec)) t2; + count +------- + 4 +(1 row) + +SELECT COUNT(*) FROM t; + count +------- + 5 +(1 row) + +TRUNCATE t; +SELECT * FROM t ORDER BY val <-> '[3,3,3]'; + val +----- +(0 rows) + +DROP TABLE t; diff --git a/test/sql/cast.sql b/test/sql/cast.sql index c73ab07..2a43671 100644 --- a/test/sql/cast.sql +++ b/test/sql/cast.sql @@ -10,6 +10,12 @@ SELECT '{-Infinity}'::real[]::vector; SELECT '{}'::real[]::vector; SELECT '{{1}}'::real[]::vector; SELECT '[1,2,3]'::vector::real[]; +SELECT '[1,2,3]'::vector::halfvec; +SELECT '[1,2,3]'::halfvec::vector; +SELECT '[1,2,3]'::vector::halfvec(2); +SELECT '[1,2,3]'::halfvec::vector(2); +SELECT '[65520]'::vector::halfvec; +SELECT '[1e-8]'::vector::halfvec; SELECT array_agg(n)::vector FROM generate_series(1, 16001) n; SELECT array_to_vector(array_agg(n), 16001, false) FROM generate_series(1, 16001) n; diff --git a/test/sql/copy.sql b/test/sql/copy.sql index 2820090..ee93a50 100644 --- a/test/sql/copy.sql +++ b/test/sql/copy.sql @@ -1,7 +1,7 @@ -CREATE TABLE t (val vector(3)); -INSERT INTO t (val) VALUES ('[0,0,0]'), ('[1,2,3]'), ('[1,1,1]'), (NULL); +CREATE TABLE t (val vector(3), val2 halfvec(3)); +INSERT INTO t (val, val2) VALUES ('[0,0,0]', '[0,0,0]'), ('[1,2,3]', '[1,2,3]'), ('[1,1,1]', '[1,1,1]'), (NULL, NULL); -CREATE TABLE t2 (val vector(3)); +CREATE TABLE t2 (val vector(3), val2 halfvec(3)); \copy t TO 'results/data.bin' WITH (FORMAT binary) \copy t2 FROM 'results/data.bin' WITH (FORMAT binary) diff --git a/test/sql/halfvec_functions.sql b/test/sql/halfvec_functions.sql new file mode 100644 index 0000000..17f28a5 --- /dev/null +++ b/test/sql/halfvec_functions.sql @@ -0,0 +1,23 @@ +SELECT l2_distance('[0,0]'::halfvec, '[3,4]'); +SELECT l2_distance('[0,0]'::halfvec, '[0,1]'); +SELECT l2_distance('[1,2]'::halfvec, '[3]'); +SELECT '[0,0]'::halfvec <-> '[3,4]'; + +SELECT inner_product('[1,2]'::halfvec, '[3,4]'); +SELECT inner_product('[1,2]'::halfvec, '[3]'); +SELECT inner_product('[65504]'::halfvec, '[65504]'); +SELECT '[1,2]'::halfvec <#> '[3,4]'; + +SELECT cosine_distance('[1,2]'::halfvec, '[2,4]'); +SELECT cosine_distance('[1,2]'::halfvec, '[0,0]'); +SELECT cosine_distance('[1,1]'::halfvec, '[1,1]'); +SELECT cosine_distance('[1,0]'::halfvec, '[0,2]'); +SELECT cosine_distance('[1,1]'::halfvec, '[-1,-1]'); +SELECT cosine_distance('[1,2]'::halfvec, '[3]'); +SELECT cosine_distance('[1,1]'::halfvec, '[1.1,1.1]'); +SELECT cosine_distance('[1,1]'::halfvec, '[-1.1,-1.1]'); +SELECT '[1,2]'::halfvec <=> '[2,4]'; + +SELECT l1_distance('[0,0]'::halfvec, '[3,4]'); +SELECT l1_distance('[0,0]'::halfvec, '[0,1]'); +SELECT l1_distance('[1,2]'::halfvec, '[3]'); diff --git a/test/sql/halfvec_input.sql b/test/sql/halfvec_input.sql new file mode 100644 index 0000000..1ae3abd --- /dev/null +++ b/test/sql/halfvec_input.sql @@ -0,0 +1,34 @@ +SELECT '[1,2,3]'::halfvec; +SELECT '[-1,-2,-3]'::halfvec; +SELECT '[1.,2.,3.]'::halfvec; +SELECT ' [ 1, 2 , 3 ] '::halfvec; +SELECT '[1.23456]'::halfvec; +SELECT '[hello,1]'::halfvec; +SELECT '[NaN,1]'::halfvec; +SELECT '[Infinity,1]'::halfvec; +SELECT '[-Infinity,1]'::halfvec; +SELECT '[65519,-65519]'::halfvec; +SELECT '[65520,-65520]'::halfvec; +SELECT '[1e-8,-1e-8]'::halfvec; +SELECT '[4e38,1]'::halfvec; +SELECT '[1,2,3'::halfvec; +SELECT '[1,2,3]9'::halfvec; +SELECT '1,2,3'::halfvec; +SELECT ''::halfvec; +SELECT '['::halfvec; +SELECT '[,'::halfvec; +SELECT '[]'::halfvec; +SELECT '[1,]'::halfvec; +SELECT '[1a]'::halfvec; +SELECT '[1,,3]'::halfvec; +SELECT '[1, ,3]'::halfvec; + +SELECT '[1,2,3]'::halfvec(3); +SELECT '[1,2,3]'::halfvec(2); +SELECT '[1,2,3]'::halfvec(3, 2); +SELECT '[1,2,3]'::halfvec('a'); +SELECT '[1,2,3]'::halfvec(0); +SELECT '[1,2,3]'::halfvec(16001); + +SELECT unnest('{"[1,2,3]", "[4,5,6]"}'::halfvec[]); +SELECT '{"[1,2,3]"}'::halfvec(2)[]; diff --git a/test/sql/hnsw_halfvec_cosine.sql b/test/sql/hnsw_halfvec_cosine.sql new file mode 100644 index 0000000..e5473b4 --- /dev/null +++ b/test/sql/hnsw_halfvec_cosine.sql @@ -0,0 +1,13 @@ +SET enable_seqscan = off; + +CREATE TABLE t (val halfvec(3)); +INSERT INTO t (val) VALUES ('[0,0,0]'), ('[1,2,3]'), ('[1,1,1]'), (NULL); +CREATE INDEX ON t USING hnsw (val halfvec_cosine_ops); + +INSERT INTO t (val) VALUES ('[1,2,4]'); + +SELECT * FROM t ORDER BY val <=> '[3,3,3]'; +SELECT COUNT(*) FROM (SELECT * FROM t ORDER BY val <=> '[0,0,0]') t2; +SELECT COUNT(*) FROM (SELECT * FROM t ORDER BY val <=> (SELECT NULL::halfvec)) t2; + +DROP TABLE t; diff --git a/test/sql/hnsw_halfvec_ip.sql b/test/sql/hnsw_halfvec_ip.sql new file mode 100644 index 0000000..bdf32f7 --- /dev/null +++ b/test/sql/hnsw_halfvec_ip.sql @@ -0,0 +1,12 @@ +SET enable_seqscan = off; + +CREATE TABLE t (val halfvec(3)); +INSERT INTO t (val) VALUES ('[0,0,0]'), ('[1,2,3]'), ('[1,1,1]'), (NULL); +CREATE INDEX ON t USING hnsw (val halfvec_ip_ops); + +INSERT INTO t (val) VALUES ('[1,2,4]'); + +SELECT * FROM t ORDER BY val <#> '[3,3,3]'; +SELECT COUNT(*) FROM (SELECT * FROM t ORDER BY val <#> (SELECT NULL::halfvec)) t2; + +DROP TABLE t; diff --git a/test/sql/hnsw_halfvec_l2.sql b/test/sql/hnsw_halfvec_l2.sql new file mode 100644 index 0000000..bc20066 --- /dev/null +++ b/test/sql/hnsw_halfvec_l2.sql @@ -0,0 +1,16 @@ +SET enable_seqscan = off; + +CREATE TABLE t (val halfvec(3)); +INSERT INTO t (val) VALUES ('[0,0,0]'), ('[1,2,3]'), ('[1,1,1]'), (NULL); +CREATE INDEX ON t USING hnsw (val halfvec_l2_ops); + +INSERT INTO t (val) VALUES ('[1,2,4]'); + +SELECT * FROM t ORDER BY val <-> '[3,3,3]'; +SELECT COUNT(*) FROM (SELECT * FROM t ORDER BY val <-> (SELECT NULL::halfvec)) t2; +SELECT COUNT(*) FROM t; + +TRUNCATE t; +SELECT * FROM t ORDER BY val <-> '[3,3,3]'; + +DROP TABLE t; diff --git a/test/t/021_hnsw_halfvec_build_recall.pl b/test/t/021_hnsw_halfvec_build_recall.pl new file mode 100644 index 0000000..e1f3521 --- /dev/null +++ b/test/t/021_hnsw_halfvec_build_recall.pl @@ -0,0 +1,132 @@ +use strict; +use warnings; +use PostgresNode; +use TestLib; +use Test::More; + +my $node; +my @queries = (); +my @expected; +my $limit = 20; +my $dim = 10; +my $array_sql = join(",", ('random()') x $dim); + +sub test_recall +{ + my ($min, $operator) = @_; + my $correct = 0; + my $total = 0; + + my $explain = $node->safe_psql("postgres", qq( + SET enable_seqscan = off; + EXPLAIN ANALYZE SELECT i FROM tst ORDER BY v $operator '$queries[0]' LIMIT $limit; + )); + like($explain, qr/Index Scan/); + + for my $i (0 .. $#queries) + { + my $actual = $node->safe_psql("postgres", qq( + SET enable_seqscan = off; + SELECT i FROM tst ORDER BY v $operator '$queries[$i]' LIMIT $limit; + )); + my @actual_ids = split("\n", $actual); + my %actual_set = map { $_ => 1 } @actual_ids; + + my @expected_ids = split("\n", $expected[$i]); + + foreach (@expected_ids) + { + if (exists($actual_set{$_})) + { + $correct++; + } + $total++; + } + } + + cmp_ok($correct / $total, ">=", $min, $operator); +} + +# Initialize node +$node = get_new_node('node'); +$node->init; +$node->start; + +# Create table +$node->safe_psql("postgres", "CREATE EXTENSION vector;"); +$node->safe_psql("postgres", "CREATE TABLE tst (i int4, v halfvec($dim));"); +$node->safe_psql("postgres", + "INSERT INTO tst SELECT i, ARRAY[$array_sql] FROM generate_series(1, 10000) i;" +); + +# Generate queries +for (1 .. 20) +{ + my @r = (); + for (1 .. $dim) + { + push(@r, rand()); + } + push(@queries, "[" . join(",", @r) . "]"); +} + +# Check each index type +my @operators = ("<->", "<#>", "<=>"); +my @opclasses = ("halfvec_l2_ops", "halfvec_ip_ops", "halfvec_cosine_ops"); + +for my $i (0 .. $#operators) +{ + my $operator = $operators[$i]; + my $opclass = $opclasses[$i]; + + # Get exact results + @expected = (); + foreach (@queries) + { + my $res = $node->safe_psql("postgres", "SELECT i FROM tst ORDER BY v $operator '$_' LIMIT $limit;"); + push(@expected, $res); + } + + # Build index serially + $node->safe_psql("postgres", qq( + SET max_parallel_maintenance_workers = 0; + CREATE INDEX idx ON tst USING hnsw (v $opclass); + )); + + # Test approximate results + my $min = $operator eq "<#>" ? 0.95 : 0.99; + test_recall($min, $operator); + + $node->safe_psql("postgres", "DROP INDEX idx;"); + + # Build index in parallel in memory + my ($ret, $stdout, $stderr) = $node->psql("postgres", qq( + SET client_min_messages = DEBUG; + SET min_parallel_table_scan_size = 1; + CREATE INDEX idx ON tst USING hnsw (v $opclass); + )); + is($ret, 0, $stderr); + like($stderr, qr/using \d+ parallel workers/); + + # Test approximate results + test_recall($min, $operator); + + $node->safe_psql("postgres", "DROP INDEX idx;"); + + # Build index in parallel on disk + # Set parallel_workers on table to use workers with low maintenance_work_mem + ($ret, $stdout, $stderr) = $node->psql("postgres", qq( + ALTER TABLE tst SET (parallel_workers = 2); + SET client_min_messages = DEBUG; + SET maintenance_work_mem = '4MB'; + CREATE INDEX idx ON tst USING hnsw (v $opclass); + ALTER TABLE tst RESET (parallel_workers); + )); + is($ret, 0, $stderr); + like($stderr, qr/using \d+ parallel workers/); + like($stderr, qr/hnsw graph no longer fits into maintenance_work_mem/); + + $node->safe_psql("postgres", "DROP INDEX idx;"); +} + +done_testing();