diff --git a/CHANGELOG.md b/CHANGELOG.md index 9bcea19..9e3aa96 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -1,3 +1,7 @@ +## 0.7.0 (unreleased) + +- Added `halfvec` type + ## 0.6.2 (2024-03-18) - Reduced lock contention with parallel HNSW index builds diff --git a/Makefile b/Makefile index d5f61ff..469840f 100644 --- a/Makefile +++ b/Makefile @@ -3,8 +3,8 @@ EXTVERSION = 0.6.2 MODULE_big = vector DATA = $(wildcard sql/*--*.sql) -OBJS = src/hnsw.o src/hnswbuild.o src/hnswinsert.o src/hnswscan.o src/hnswutils.o src/hnswvacuum.o src/ivfbuild.o src/ivfflat.o src/ivfinsert.o src/ivfkmeans.o src/ivfscan.o src/ivfutils.o src/ivfvacuum.o src/vector.o -HEADERS = src/vector.h +OBJS = src/halfvec.o src/hnsw.o src/hnswbuild.o src/hnswinsert.o src/hnswscan.o src/hnswutils.o src/hnswvacuum.o src/ivfbuild.o src/ivfflat.o src/ivfinsert.o src/ivfkmeans.o src/ivfscan.o src/ivfutils.o src/ivfvacuum.o src/vector.o +HEADERS = src/halfvec.h src/vector.h TESTS = $(wildcard test/sql/*.sql) REGRESS = $(patsubst test/sql/%.sql,%,$(TESTS)) diff --git a/Makefile.win b/Makefile.win index 1bb193e..b299a54 100644 --- a/Makefile.win +++ b/Makefile.win @@ -1,8 +1,8 @@ EXTENSION = vector EXTVERSION = 0.6.2 -OBJS = src\hnsw.obj src\hnswbuild.obj src\hnswinsert.obj src\hnswscan.obj src\hnswutils.obj src\hnswvacuum.obj src\ivfbuild.obj src\ivfflat.obj src\ivfinsert.obj src\ivfkmeans.obj src\ivfscan.obj src\ivfutils.obj src\ivfvacuum.obj src\vector.obj -HEADERS = src\vector.h +OBJS = src\halfvec.obj src\hnsw.obj src\hnswbuild.obj src\hnswinsert.obj src\hnswscan.obj src\hnswutils.obj src\hnswvacuum.obj src\ivfbuild.obj src\ivfflat.obj src\ivfinsert.obj src\ivfkmeans.obj src\ivfscan.obj src\ivfutils.obj src\ivfvacuum.obj src\vector.obj +HEADERS = src\halfvec.h src\vector.h REGRESS = btree cast copy functions input ivfflat_cosine ivfflat_ip ivfflat_l2 ivfflat_options ivfflat_unlogged REGRESS_OPTS = --inputdir=test --load-extension=$(EXTENSION) diff --git a/sql/vector--0.6.2--0.7.0.sql b/sql/vector--0.6.2--0.7.0.sql new file mode 100644 index 0000000..e0a7a04 --- /dev/null +++ b/sql/vector--0.6.2--0.7.0.sql @@ -0,0 +1,102 @@ +-- complain if script is sourced in psql, rather than via CREATE EXTENSION +\echo Use "ALTER EXTENSION vector UPDATE TO '0.7.0'" to load this file. \quit + +CREATE TYPE halfvec; + +CREATE FUNCTION halfvec_in(cstring, oid, integer) RETURNS halfvec + AS 'MODULE_PATHNAME' LANGUAGE C IMMUTABLE STRICT PARALLEL SAFE; + +CREATE FUNCTION halfvec_out(halfvec) RETURNS cstring + AS 'MODULE_PATHNAME' LANGUAGE C IMMUTABLE STRICT PARALLEL SAFE; + +CREATE FUNCTION halfvec_typmod_in(cstring[]) RETURNS integer + AS 'MODULE_PATHNAME' LANGUAGE C IMMUTABLE STRICT PARALLEL SAFE; + +CREATE FUNCTION halfvec_recv(internal, oid, integer) RETURNS halfvec + AS 'MODULE_PATHNAME' LANGUAGE C IMMUTABLE STRICT PARALLEL SAFE; + +CREATE FUNCTION halfvec_send(halfvec) RETURNS bytea + AS 'MODULE_PATHNAME' LANGUAGE C IMMUTABLE STRICT PARALLEL SAFE; + +CREATE TYPE halfvec ( + INPUT = halfvec_in, + OUTPUT = halfvec_out, + TYPMOD_IN = halfvec_typmod_in, + RECEIVE = halfvec_recv, + SEND = halfvec_send, + STORAGE = external +); + +CREATE FUNCTION l2_distance(halfvec, halfvec) RETURNS float8 + AS 'MODULE_PATHNAME', 'halfvec_l2_distance' LANGUAGE C IMMUTABLE STRICT PARALLEL SAFE; + +CREATE FUNCTION inner_product(halfvec, halfvec) RETURNS float8 + AS 'MODULE_PATHNAME', 'halfvec_inner_product' LANGUAGE C IMMUTABLE STRICT PARALLEL SAFE; + +CREATE FUNCTION cosine_distance(halfvec, halfvec) RETURNS float8 + AS 'MODULE_PATHNAME', 'halfvec_cosine_distance' LANGUAGE C IMMUTABLE STRICT PARALLEL SAFE; + +CREATE FUNCTION l1_distance(halfvec, halfvec) RETURNS float8 + AS 'MODULE_PATHNAME', 'halfvec_l1_distance' LANGUAGE C IMMUTABLE STRICT PARALLEL SAFE; + +CREATE FUNCTION halfvec_l2_squared_distance(halfvec, halfvec) RETURNS float8 + AS 'MODULE_PATHNAME' LANGUAGE C IMMUTABLE STRICT PARALLEL SAFE; + +CREATE FUNCTION halfvec_negative_inner_product(halfvec, halfvec) RETURNS float8 + AS 'MODULE_PATHNAME' LANGUAGE C IMMUTABLE STRICT PARALLEL SAFE; + +CREATE FUNCTION halfvec(halfvec, integer, boolean) RETURNS halfvec + AS 'MODULE_PATHNAME' LANGUAGE C IMMUTABLE STRICT PARALLEL SAFE; + +CREATE FUNCTION array_to_halfvec(integer[], integer, boolean) RETURNS halfvec + AS 'MODULE_PATHNAME' LANGUAGE C IMMUTABLE STRICT PARALLEL SAFE; + +CREATE FUNCTION array_to_halfvec(real[], integer, boolean) RETURNS halfvec + AS 'MODULE_PATHNAME' LANGUAGE C IMMUTABLE STRICT PARALLEL SAFE; + +CREATE FUNCTION array_to_halfvec(double precision[], integer, boolean) RETURNS halfvec + AS 'MODULE_PATHNAME' LANGUAGE C IMMUTABLE STRICT PARALLEL SAFE; + +CREATE FUNCTION array_to_halfvec(numeric[], integer, boolean) RETURNS halfvec + AS 'MODULE_PATHNAME' LANGUAGE C IMMUTABLE STRICT PARALLEL SAFE; + +CREATE FUNCTION halfvec_to_float4(halfvec, integer, boolean) RETURNS real[] + AS 'MODULE_PATHNAME' LANGUAGE C IMMUTABLE STRICT PARALLEL SAFE; + +CREATE CAST (halfvec AS halfvec) + WITH FUNCTION halfvec(halfvec, integer, boolean) AS IMPLICIT; + +CREATE CAST (halfvec AS real[]) + WITH FUNCTION halfvec_to_float4(halfvec, integer, boolean) AS IMPLICIT; + +CREATE CAST (integer[] AS halfvec) + WITH FUNCTION array_to_halfvec(integer[], integer, boolean) AS ASSIGNMENT; + +CREATE CAST (real[] AS halfvec) + WITH FUNCTION array_to_halfvec(real[], integer, boolean) AS ASSIGNMENT; + +CREATE CAST (double precision[] AS halfvec) + WITH FUNCTION array_to_halfvec(double precision[], integer, boolean) AS ASSIGNMENT; + +CREATE CAST (numeric[] AS halfvec) + WITH FUNCTION array_to_halfvec(numeric[], integer, boolean) AS ASSIGNMENT; + +CREATE OPERATOR <-> ( + LEFTARG = halfvec, RIGHTARG = halfvec, PROCEDURE = l2_distance, + COMMUTATOR = '<->' +); + +CREATE OPERATOR <#> ( + LEFTARG = halfvec, RIGHTARG = halfvec, PROCEDURE = halfvec_negative_inner_product, + COMMUTATOR = '<#>' +); + +CREATE OPERATOR <=> ( + LEFTARG = halfvec, RIGHTARG = halfvec, PROCEDURE = cosine_distance, + COMMUTATOR = '<=>' +); + +CREATE OPERATOR CLASS halfvec_l2_ops + FOR TYPE halfvec USING hnsw AS + OPERATOR 1 <-> (halfvec, halfvec) FOR ORDER BY float_ops, + FUNCTION 1 halfvec_l2_squared_distance(halfvec, halfvec); diff --git a/sql/vector.sql b/sql/vector.sql index 141e83c..9b78384 100644 --- a/sql/vector.sql +++ b/sql/vector.sql @@ -287,3 +287,131 @@ CREATE OPERATOR CLASS vector_cosine_ops OPERATOR 1 <=> (vector, vector) FOR ORDER BY float_ops, FUNCTION 1 vector_negative_inner_product(vector, vector), FUNCTION 2 vector_norm(vector); + +-- halfvec type + +CREATE TYPE halfvec; + +CREATE FUNCTION halfvec_in(cstring, oid, integer) RETURNS halfvec + AS 'MODULE_PATHNAME' LANGUAGE C IMMUTABLE STRICT PARALLEL SAFE; + +CREATE FUNCTION halfvec_out(halfvec) RETURNS cstring + AS 'MODULE_PATHNAME' LANGUAGE C IMMUTABLE STRICT PARALLEL SAFE; + +CREATE FUNCTION halfvec_typmod_in(cstring[]) RETURNS integer + AS 'MODULE_PATHNAME' LANGUAGE C IMMUTABLE STRICT PARALLEL SAFE; + +CREATE FUNCTION halfvec_recv(internal, oid, integer) RETURNS halfvec + AS 'MODULE_PATHNAME' LANGUAGE C IMMUTABLE STRICT PARALLEL SAFE; + +CREATE FUNCTION halfvec_send(halfvec) RETURNS bytea + AS 'MODULE_PATHNAME' LANGUAGE C IMMUTABLE STRICT PARALLEL SAFE; + +CREATE TYPE halfvec ( + INPUT = halfvec_in, + OUTPUT = halfvec_out, + TYPMOD_IN = halfvec_typmod_in, + RECEIVE = halfvec_recv, + SEND = halfvec_send, + STORAGE = external +); + +-- halfvec functions + +CREATE FUNCTION l2_distance(halfvec, halfvec) RETURNS float8 + AS 'MODULE_PATHNAME', 'halfvec_l2_distance' LANGUAGE C IMMUTABLE STRICT PARALLEL SAFE; + +CREATE FUNCTION inner_product(halfvec, halfvec) RETURNS float8 + AS 'MODULE_PATHNAME', 'halfvec_inner_product' LANGUAGE C IMMUTABLE STRICT PARALLEL SAFE; + +CREATE FUNCTION cosine_distance(halfvec, halfvec) RETURNS float8 + AS 'MODULE_PATHNAME', 'halfvec_cosine_distance' LANGUAGE C IMMUTABLE STRICT PARALLEL SAFE; + +CREATE FUNCTION l1_distance(halfvec, halfvec) RETURNS float8 + AS 'MODULE_PATHNAME', 'halfvec_l1_distance' LANGUAGE C IMMUTABLE STRICT PARALLEL SAFE; + +CREATE FUNCTION halfvec_norm(halfvec) RETURNS float8 + AS 'MODULE_PATHNAME' LANGUAGE C IMMUTABLE STRICT PARALLEL SAFE; + +-- halfvec private functions + +CREATE FUNCTION halfvec_l2_squared_distance(halfvec, halfvec) RETURNS float8 + AS 'MODULE_PATHNAME' LANGUAGE C IMMUTABLE STRICT PARALLEL SAFE; + +CREATE FUNCTION halfvec_negative_inner_product(halfvec, halfvec) RETURNS float8 + AS 'MODULE_PATHNAME' LANGUAGE C IMMUTABLE STRICT PARALLEL SAFE; + +-- halfvec cast functions + +CREATE FUNCTION halfvec(halfvec, integer, boolean) RETURNS halfvec + AS 'MODULE_PATHNAME' LANGUAGE C IMMUTABLE STRICT PARALLEL SAFE; + +CREATE FUNCTION array_to_halfvec(integer[], integer, boolean) RETURNS halfvec + AS 'MODULE_PATHNAME' LANGUAGE C IMMUTABLE STRICT PARALLEL SAFE; + +CREATE FUNCTION array_to_halfvec(real[], integer, boolean) RETURNS halfvec + AS 'MODULE_PATHNAME' LANGUAGE C IMMUTABLE STRICT PARALLEL SAFE; + +CREATE FUNCTION array_to_halfvec(double precision[], integer, boolean) RETURNS halfvec + AS 'MODULE_PATHNAME' LANGUAGE C IMMUTABLE STRICT PARALLEL SAFE; + +CREATE FUNCTION array_to_halfvec(numeric[], integer, boolean) RETURNS halfvec + AS 'MODULE_PATHNAME' LANGUAGE C IMMUTABLE STRICT PARALLEL SAFE; + +CREATE FUNCTION halfvec_to_float4(halfvec, integer, boolean) RETURNS real[] + AS 'MODULE_PATHNAME' LANGUAGE C IMMUTABLE STRICT PARALLEL SAFE; + +-- halfvec casts + +CREATE CAST (halfvec AS halfvec) + WITH FUNCTION halfvec(halfvec, integer, boolean) AS IMPLICIT; + +CREATE CAST (halfvec AS real[]) + WITH FUNCTION halfvec_to_float4(halfvec, integer, boolean) AS IMPLICIT; + +CREATE CAST (integer[] AS halfvec) + WITH FUNCTION array_to_halfvec(integer[], integer, boolean) AS ASSIGNMENT; + +CREATE CAST (real[] AS halfvec) + WITH FUNCTION array_to_halfvec(real[], integer, boolean) AS ASSIGNMENT; + +CREATE CAST (double precision[] AS halfvec) + WITH FUNCTION array_to_halfvec(double precision[], integer, boolean) AS ASSIGNMENT; + +CREATE CAST (numeric[] AS halfvec) + WITH FUNCTION array_to_halfvec(numeric[], integer, boolean) AS ASSIGNMENT; + +-- halfvec operators + +CREATE OPERATOR <-> ( + LEFTARG = halfvec, RIGHTARG = halfvec, PROCEDURE = l2_distance, + COMMUTATOR = '<->' +); + +CREATE OPERATOR <#> ( + LEFTARG = halfvec, RIGHTARG = halfvec, PROCEDURE = halfvec_negative_inner_product, + COMMUTATOR = '<#>' +); + +CREATE OPERATOR <=> ( + LEFTARG = halfvec, RIGHTARG = halfvec, PROCEDURE = cosine_distance, + COMMUTATOR = '<=>' +); + +-- halfvec opclasses + +CREATE OPERATOR CLASS halfvec_l2_ops + FOR TYPE halfvec USING hnsw AS + OPERATOR 1 <-> (halfvec, halfvec) FOR ORDER BY float_ops, + FUNCTION 1 halfvec_l2_squared_distance(halfvec, halfvec); + +CREATE OPERATOR CLASS halfvec_ip_ops + FOR TYPE halfvec USING hnsw AS + OPERATOR 1 <#> (halfvec, halfvec) FOR ORDER BY float_ops, + FUNCTION 1 halfvec_negative_inner_product(halfvec, halfvec); + +CREATE OPERATOR CLASS halfvec_cosine_ops + FOR TYPE halfvec USING hnsw AS + OPERATOR 1 <=> (halfvec, halfvec) FOR ORDER BY float_ops, + FUNCTION 1 halfvec_negative_inner_product(halfvec, halfvec), + FUNCTION 2 halfvec_norm(halfvec); diff --git a/src/halfvec.c b/src/halfvec.c new file mode 100644 index 0000000..13389bf --- /dev/null +++ b/src/halfvec.c @@ -0,0 +1,901 @@ +#include "postgres.h" + +#include + +#include "catalog/pg_type.h" +#include "common/shortest_dec.h" +#include "fmgr.h" +#include "halfvec.h" +#include "lib/stringinfo.h" +#include "libpq/pqformat.h" +#include "port.h" /* for strtof() */ +#include "utils/array.h" +#include "utils/builtins.h" +#include "utils/float.h" +#include "utils/lsyscache.h" +#include "utils/numeric.h" + +/* + * Check if half is NaN + */ +static inline bool +HalfIsNan(half num) +{ +#ifdef FLT16_SUPPORT + return isnan(num); +#else + return (num & 0x7C00) == 0x7C00 && (num & 0x7FFF) != 0x7C00; +#endif +} + +/* + * Check if half is infinite + */ +static inline bool +HalfIsInf(half num) +{ +#ifdef FLT16_SUPPORT + return isinf(num); +#else + return (num & 0x7FFF) == 0x7C00; +#endif +} + +/* + * Check if half is zero + */ +static inline bool +HalfIsZero(half num) +{ +#ifdef FLT16_SUPPORT + return num == 0; +#else + return (num & 0x7FFF) == 0x0000; +#endif +} + +/* + * Get a half from a message buffer + */ +static half +pq_getmsghalf(StringInfo msg) +{ + union + { + half h; + uint16 i; + } swap; + + swap.i = pq_getmsgint(msg, 2); + return swap.h; +} + +/* + * Append a half to a StringInfo buffer + */ +static void +pq_sendhalf(StringInfo buf, half h) +{ + union + { + half h; + uint16 i; + } swap; + + swap.h = h; + pq_sendint16(buf, swap.i); +} + +/* + * Convert a half to a float4 + */ +static float +HalfToFloat4(half num) +{ +#ifdef FLT16_SUPPORT + return (float) num; +#else + /* TODO Improve performance */ + + /* Assumes same endianness for floats and integers */ + /* TODO Use union to swap */ + uint16 bin = *((uint16 *) &num); + uint32 exponent = (bin & 0x7C00) >> 10; + uint32 mantissa = bin & 0x03FF; + + /* Sign */ + uint32 result = (bin & 0x8000) << 16; + + if (exponent == 31) + { + if (mantissa == 0) + { + /* Infinite */ + result |= 0x7F800000; + } + else + { + /* NaN */ + result |= 0x7FC00000; + result |= mantissa << 13; + } + } + else if (exponent == 0) + { + /* Subnormal */ + if (mantissa != 0) + { + exponent = -14; + + for (int i = 0; i < 10; i++) + { + mantissa <<= 1; + exponent -= 1; + + if ((mantissa >> 10) % 2 == 1) + { + mantissa &= 0x03ff; + break; + } + } + + result |= (exponent + 127) << 23; + result |= mantissa << 13; + } + } + else + { + /* Normal */ + result |= (exponent - 15 + 127) << 23; + result |= mantissa << 13; + } + + /* TODO Use union to swap */ + return *((float *) &result); +#endif +} + +/* + * Convert a float4 to a half + */ +static half +Float4ToHalfUnchecked(float num) +{ +#ifdef FLT16_SUPPORT + return (_Float16) num; +#else + /* TODO Improve performance */ + + /* Assumes same endianness for floats and integers */ + /* TODO Use union to swap */ + uint32 bin = *((uint32 *) &num); + int exponent = (bin & 0x7F800000) >> 23; + int mantissa = bin & 0x007FFFFF; + + /* Sign */ + uint16 result = (bin & 0x80000000) >> 16; + + if (isinf(num)) + { + /* Infinite */ + result |= 0x7C00; + } + else if (isnan(num)) + { + /* NaN */ + result |= 0x7E00; + result |= mantissa >> 13; + } + else if (exponent > 98) + { + int m; + int gr; + int s; + + exponent -= 127; + s = mantissa & 0x00000FFF; + + /* Subnormal */ + if (exponent < -14) + { + int diff = -exponent - 14; + + mantissa >>= diff; + mantissa += 1 << (23 - diff); + s |= mantissa & 0x00000FFF; + } + + m = mantissa >> 13; + + /* Round */ + gr = (mantissa >> 12) % 4; + if (gr == 3 || (gr == 1 && s != 0)) + m += 1; + + if (m == 1024) + { + m = 0; + exponent += 1; + } + + if (exponent > 15) + { + /* Infinite */ + result |= 0x7C00; + } + else + { + if (exponent >= -14) + result |= (exponent + 15) << 10; + + result |= m; + } + } + + /* TODO Use union to swap */ + return *((half *) & result); +#endif +} + +/* + * Convert a float4 to a half + */ +static half +Float4ToHalf(float num) +{ + half result = Float4ToHalfUnchecked(num); + + if (unlikely(HalfIsInf(result)) && !isinf(num)) + float_overflow_error(); + if (unlikely(HalfIsZero(result)) && num != 0.0) + float_underflow_error(); + + return result; +} + +/* + * Ensure same dimensions + */ +static inline void +CheckDims(HalfVector * a, HalfVector * b) +{ + if (a->dim != b->dim) + ereport(ERROR, + (errcode(ERRCODE_DATA_EXCEPTION), + errmsg("different halfvec dimensions %d and %d", a->dim, b->dim))); +} + +/* + * Ensure expected dimensions + */ +static inline void +CheckExpectedDim(int32 typmod, int dim) +{ + if (typmod != -1 && typmod != dim) + ereport(ERROR, + (errcode(ERRCODE_DATA_EXCEPTION), + errmsg("expected %d dimensions, not %d", typmod, dim))); +} + +/* + * Ensure valid dimensions + */ +static inline void +CheckDim(int dim) +{ + if (dim < 1) + ereport(ERROR, + (errcode(ERRCODE_DATA_EXCEPTION), + errmsg("halfvec must have at least 1 dimension"))); + + if (dim > HALFVEC_MAX_DIM) + ereport(ERROR, + (errcode(ERRCODE_PROGRAM_LIMIT_EXCEEDED), + errmsg("halfvec cannot have more than %d dimensions", HALFVEC_MAX_DIM))); +} + +/* + * Ensure finite element + */ +static inline void +CheckElement(half value) +{ + if (HalfIsNan(value)) + ereport(ERROR, + (errcode(ERRCODE_DATA_EXCEPTION), + errmsg("NaN not allowed in halfvec"))); + + if (HalfIsInf(value)) + ereport(ERROR, + (errcode(ERRCODE_DATA_EXCEPTION), + errmsg("infinite value not allowed in halfvec"))); +} + +/* + * Allocate and initialize a new half vector + */ +HalfVector * +InitHalfVector(int dim) +{ + HalfVector *result; + int size; + + size = HALFVEC_SIZE(dim); + result = (HalfVector *) palloc0(size); + SET_VARSIZE(result, size); + result->dim = dim; + + return result; +} + +/* + * Check for whitespace, since array_isspace() is static + */ +static inline bool +halfvec_isspace(char ch) +{ + if (ch == ' ' || + ch == '\t' || + ch == '\n' || + ch == '\r' || + ch == '\v' || + ch == '\f') + return true; + return false; +} + +#if PG_VERSION_NUM < 120003 +static pg_noinline void +float_overflow_error(void) +{ + ereport(ERROR, + (errcode(ERRCODE_NUMERIC_VALUE_OUT_OF_RANGE), + errmsg("value out of range: overflow"))); +} + +static pg_noinline void +float_underflow_error(void) +{ + ereport(ERROR, + (errcode(ERRCODE_NUMERIC_VALUE_OUT_OF_RANGE), + errmsg("value out of range: underflow"))); +} +#endif + +/* + * Convert textual representation to internal representation + */ +PGDLLEXPORT PG_FUNCTION_INFO_V1(halfvec_in); +Datum +halfvec_in(PG_FUNCTION_ARGS) +{ + char *lit = PG_GETARG_CSTRING(0); + int32 typmod = PG_GETARG_INT32(2); + half x[HALFVEC_MAX_DIM]; + int dim = 0; + char *pt; + char *stringEnd; + HalfVector *result; + char *litcopy = pstrdup(lit); + char *str = litcopy; + + while (halfvec_isspace(*str)) + str++; + + if (*str != '[') + ereport(ERROR, + (errcode(ERRCODE_INVALID_TEXT_REPRESENTATION), + errmsg("malformed halfvec literal: \"%s\"", lit), + errdetail("Vector contents must start with \"[\"."))); + + str++; + pt = strtok(str, ","); + stringEnd = pt; + + while (pt != NULL && *stringEnd != ']') + { + if (dim == HALFVEC_MAX_DIM) + ereport(ERROR, + (errcode(ERRCODE_PROGRAM_LIMIT_EXCEEDED), + errmsg("halfvec cannot have more than %d dimensions", HALFVEC_MAX_DIM))); + + while (halfvec_isspace(*pt)) + pt++; + + /* Check for empty string like float4in */ + if (*pt == '\0') + ereport(ERROR, + (errcode(ERRCODE_INVALID_TEXT_REPRESENTATION), + errmsg("invalid input syntax for type halfvec: \"%s\"", lit))); + + /* Use strtof like float4in to avoid a double-rounding problem */ + x[dim] = Float4ToHalf(strtof(pt, &stringEnd)); + CheckElement(x[dim]); + dim++; + + if (stringEnd == pt) + ereport(ERROR, + (errcode(ERRCODE_INVALID_TEXT_REPRESENTATION), + errmsg("invalid input syntax for type halfvec: \"%s\"", lit))); + + while (halfvec_isspace(*stringEnd)) + stringEnd++; + + if (*stringEnd != '\0' && *stringEnd != ']') + ereport(ERROR, + (errcode(ERRCODE_INVALID_TEXT_REPRESENTATION), + errmsg("invalid input syntax for type halfvec: \"%s\"", lit))); + + pt = strtok(NULL, ","); + } + + if (stringEnd == NULL || *stringEnd != ']') + ereport(ERROR, + (errcode(ERRCODE_INVALID_TEXT_REPRESENTATION), + errmsg("malformed halfvec literal: \"%s\"", lit), + errdetail("Unexpected end of input."))); + + stringEnd++; + + /* Only whitespace is allowed after the closing brace */ + while (halfvec_isspace(*stringEnd)) + stringEnd++; + + if (*stringEnd != '\0') + ereport(ERROR, + (errcode(ERRCODE_INVALID_TEXT_REPRESENTATION), + errmsg("malformed halfvec literal: \"%s\"", lit), + errdetail("Junk after closing right brace."))); + + /* Ensure no consecutive delimiters since strtok skips */ + for (pt = lit + 1; *pt != '\0'; pt++) + { + if (pt[-1] == ',' && *pt == ',') + ereport(ERROR, + (errcode(ERRCODE_INVALID_TEXT_REPRESENTATION), + errmsg("malformed halfvec literal: \"%s\"", lit))); + } + + if (dim < 1) + ereport(ERROR, + (errcode(ERRCODE_DATA_EXCEPTION), + errmsg("halfvec must have at least 1 dimension"))); + + pfree(litcopy); + + CheckExpectedDim(typmod, dim); + + result = InitHalfVector(dim); + for (int i = 0; i < dim; i++) + result->x[i] = x[i]; + + PG_RETURN_POINTER(result); +} + +/* + * Convert internal representation to textual representation + */ +PGDLLEXPORT PG_FUNCTION_INFO_V1(halfvec_out); +Datum +halfvec_out(PG_FUNCTION_ARGS) +{ + HalfVector *vector = PG_GETARG_HALFVEC_P(0); + int dim = vector->dim; + char *buf; + char *ptr; + int n; + + /* + * Need: + * + * dim * (FLOAT_SHORTEST_DECIMAL_LEN - 1) bytes for + * float_to_shortest_decimal_bufn + * + * dim - 1 bytes for separator + * + * 3 bytes for [, ], and \0 + */ + buf = (char *) palloc(FLOAT_SHORTEST_DECIMAL_LEN * dim + 2); + ptr = buf; + + *ptr = '['; + ptr++; + for (int i = 0; i < dim; i++) + { + if (i > 0) + { + *ptr = ','; + ptr++; + } + + n = float_to_shortest_decimal_bufn(HalfToFloat4(vector->x[i]), ptr); + ptr += n; + } + *ptr = ']'; + ptr++; + *ptr = '\0'; + + PG_FREE_IF_COPY(vector, 0); + PG_RETURN_CSTRING(buf); +} + +/* + * Convert type modifier + */ +PGDLLEXPORT PG_FUNCTION_INFO_V1(halfvec_typmod_in); +Datum +halfvec_typmod_in(PG_FUNCTION_ARGS) +{ + ArrayType *ta = PG_GETARG_ARRAYTYPE_P(0); + int32 *tl; + int n; + + tl = ArrayGetIntegerTypmods(ta, &n); + + if (n != 1) + ereport(ERROR, + (errcode(ERRCODE_INVALID_PARAMETER_VALUE), + errmsg("invalid type modifier"))); + + if (*tl < 1) + ereport(ERROR, + (errcode(ERRCODE_INVALID_PARAMETER_VALUE), + errmsg("dimensions for type halfvec must be at least 1"))); + + if (*tl > HALFVEC_MAX_DIM) + ereport(ERROR, + (errcode(ERRCODE_INVALID_PARAMETER_VALUE), + errmsg("dimensions for type halfvec cannot exceed %d", HALFVEC_MAX_DIM))); + + PG_RETURN_INT32(*tl); +} + +/* + * Convert external binary representation to internal representation + */ +PGDLLEXPORT PG_FUNCTION_INFO_V1(halfvec_recv); +Datum +halfvec_recv(PG_FUNCTION_ARGS) +{ + StringInfo buf = (StringInfo) PG_GETARG_POINTER(0); + int32 typmod = PG_GETARG_INT32(2); + HalfVector *result; + int16 dim; + int16 unused; + + dim = pq_getmsgint(buf, sizeof(int16)); + unused = pq_getmsgint(buf, sizeof(int16)); + + CheckDim(dim); + CheckExpectedDim(typmod, dim); + + if (unused != 0) + ereport(ERROR, + (errcode(ERRCODE_DATA_EXCEPTION), + errmsg("expected unused to be 0, not %d", unused))); + + result = InitHalfVector(dim); + for (int i = 0; i < dim; i++) + { + result->x[i] = pq_getmsghalf(buf); + CheckElement(result->x[i]); + } + + PG_RETURN_POINTER(result); +} + +/* + * Convert internal representation to the external binary representation + */ +PGDLLEXPORT PG_FUNCTION_INFO_V1(halfvec_send); +Datum +halfvec_send(PG_FUNCTION_ARGS) +{ + HalfVector *vec = PG_GETARG_HALFVEC_P(0); + StringInfoData buf; + + pq_begintypsend(&buf); + pq_sendint(&buf, vec->dim, sizeof(int16)); + pq_sendint(&buf, vec->unused, sizeof(int16)); + for (int i = 0; i < vec->dim; i++) + pq_sendhalf(&buf, vec->x[i]); + + PG_RETURN_BYTEA_P(pq_endtypsend(&buf)); +} + +/* + * Convert half vector to half vector + * This is needed to check the type modifier + */ +PGDLLEXPORT PG_FUNCTION_INFO_V1(halfvec); +Datum +halfvec(PG_FUNCTION_ARGS) +{ + HalfVector *vec = PG_GETARG_HALFVEC_P(0); + int32 typmod = PG_GETARG_INT32(1); + + CheckExpectedDim(typmod, vec->dim); + + PG_RETURN_POINTER(vec); +} + +/* + * Convert array to half vector + */ +PGDLLEXPORT PG_FUNCTION_INFO_V1(array_to_halfvec); +Datum +array_to_halfvec(PG_FUNCTION_ARGS) +{ + ArrayType *array = PG_GETARG_ARRAYTYPE_P(0); + int32 typmod = PG_GETARG_INT32(1); + HalfVector *result; + int16 typlen; + bool typbyval; + char typalign; + Datum *elemsp; + int nelemsp; + + if (ARR_NDIM(array) > 1) + ereport(ERROR, + (errcode(ERRCODE_DATA_EXCEPTION), + errmsg("array must be 1-D"))); + + if (ARR_HASNULL(array) && array_contains_nulls(array)) + ereport(ERROR, + (errcode(ERRCODE_NULL_VALUE_NOT_ALLOWED), + errmsg("array must not contain nulls"))); + + get_typlenbyvalalign(ARR_ELEMTYPE(array), &typlen, &typbyval, &typalign); + deconstruct_array(array, ARR_ELEMTYPE(array), typlen, typbyval, typalign, &elemsp, NULL, &nelemsp); + + CheckDim(nelemsp); + CheckExpectedDim(typmod, nelemsp); + + result = InitHalfVector(nelemsp); + + if (ARR_ELEMTYPE(array) == INT4OID) + { + for (int i = 0; i < nelemsp; i++) + result->x[i] = Float4ToHalf(DatumGetInt32(elemsp[i])); + } + else if (ARR_ELEMTYPE(array) == FLOAT8OID) + { + for (int i = 0; i < nelemsp; i++) + result->x[i] = Float4ToHalf(DatumGetFloat8(elemsp[i])); + } + else if (ARR_ELEMTYPE(array) == FLOAT4OID) + { + for (int i = 0; i < nelemsp; i++) + result->x[i] = Float4ToHalf(DatumGetFloat4(elemsp[i])); + } + else if (ARR_ELEMTYPE(array) == NUMERICOID) + { + for (int i = 0; i < nelemsp; i++) + result->x[i] = Float4ToHalf(DatumGetFloat4(DirectFunctionCall1(numeric_float4, elemsp[i]))); + } + else + { + ereport(ERROR, + (errcode(ERRCODE_DATA_EXCEPTION), + errmsg("unsupported array type"))); + } + + /* + * Free allocation from deconstruct_array. Do not free individual elements + * when pass-by-reference since they point to original array. + */ + pfree(elemsp); + + /* Check elements */ + for (int i = 0; i < result->dim; i++) + CheckElement(result->x[i]); + + PG_RETURN_POINTER(result); +} + +/* + * Convert half vector to float4[] + */ +PGDLLEXPORT PG_FUNCTION_INFO_V1(halfvec_to_float4); +Datum +halfvec_to_float4(PG_FUNCTION_ARGS) +{ + HalfVector *vec = PG_GETARG_HALFVEC_P(0); + Datum *datums; + ArrayType *result; + + datums = (Datum *) palloc(sizeof(Datum) * vec->dim); + + for (int i = 0; i < vec->dim; i++) + datums[i] = Float4GetDatum(HalfToFloat4(vec->x[i])); + + /* Use TYPALIGN_INT for float4 */ + result = construct_array(datums, vec->dim, FLOAT4OID, sizeof(float4), true, TYPALIGN_INT); + + pfree(datums); + + PG_RETURN_POINTER(result); +} + +/* + * Get the L2 distance between half vectors + */ +PGDLLEXPORT PG_FUNCTION_INFO_V1(halfvec_l2_distance); +Datum +halfvec_l2_distance(PG_FUNCTION_ARGS) +{ + HalfVector *a = PG_GETARG_HALFVEC_P(0); + HalfVector *b = PG_GETARG_HALFVEC_P(1); + half *ax = a->x; + half *bx = b->x; + float distance = 0.0; + + CheckDims(a, b); + + /* Auto-vectorized */ + for (int i = 0; i < a->dim; i++) + { + float diff = HalfToFloat4(ax[i]) - HalfToFloat4(bx[i]); + + distance += diff * diff; + } + + PG_RETURN_FLOAT8(sqrt((double) distance)); +} + +/* + * Get the L2 squared distance between half vectors + */ +PGDLLEXPORT PG_FUNCTION_INFO_V1(halfvec_l2_squared_distance); +Datum +halfvec_l2_squared_distance(PG_FUNCTION_ARGS) +{ + HalfVector *a = PG_GETARG_HALFVEC_P(0); + HalfVector *b = PG_GETARG_HALFVEC_P(1); + half *ax = a->x; + half *bx = b->x; + float distance = 0.0; + + CheckDims(a, b); + + /* Auto-vectorized */ + for (int i = 0; i < a->dim; i++) + { + float diff = HalfToFloat4(ax[i]) - HalfToFloat4(bx[i]); + + distance += diff * diff; + } + + PG_RETURN_FLOAT8((double) distance); +} + +/* + * Get the inner product of two half vectors + */ +PGDLLEXPORT PG_FUNCTION_INFO_V1(halfvec_inner_product); +Datum +halfvec_inner_product(PG_FUNCTION_ARGS) +{ + HalfVector *a = PG_GETARG_HALFVEC_P(0); + HalfVector *b = PG_GETARG_HALFVEC_P(1); + half *ax = a->x; + half *bx = b->x; + float distance = 0.0; + + CheckDims(a, b); + + /* Auto-vectorized */ + for (int i = 0; i < a->dim; i++) + distance += HalfToFloat4(ax[i]) * HalfToFloat4(bx[i]); + + PG_RETURN_FLOAT8((double) distance); +} + +/* + * Get the negative inner product of two half vectors + */ +PGDLLEXPORT PG_FUNCTION_INFO_V1(halfvec_negative_inner_product); +Datum +halfvec_negative_inner_product(PG_FUNCTION_ARGS) +{ + HalfVector *a = PG_GETARG_HALFVEC_P(0); + HalfVector *b = PG_GETARG_HALFVEC_P(1); + half *ax = a->x; + half *bx = b->x; + float distance = 0.0; + + CheckDims(a, b); + + /* Auto-vectorized */ + for (int i = 0; i < a->dim; i++) + distance += HalfToFloat4(ax[i]) * HalfToFloat4(bx[i]); + + PG_RETURN_FLOAT8((double) distance * -1); +} + +/* + * Get the cosine distance between two half vectors + */ +PGDLLEXPORT PG_FUNCTION_INFO_V1(halfvec_cosine_distance); +Datum +halfvec_cosine_distance(PG_FUNCTION_ARGS) +{ + HalfVector *a = PG_GETARG_HALFVEC_P(0); + HalfVector *b = PG_GETARG_HALFVEC_P(1); + half *ax = a->x; + half *bx = b->x; + float distance = 0.0; + float norma = 0.0; + float normb = 0.0; + double similarity; + + CheckDims(a, b); + + /* Auto-vectorized */ + for (int i = 0; i < a->dim; i++) + { + float axi = HalfToFloat4(ax[i]); + float bxi = HalfToFloat4(bx[i]); + + distance += axi * bxi; + norma += axi * axi; + normb += bxi * bxi; + } + + /* Use sqrt(a * b) over sqrt(a) * sqrt(b) */ + similarity = (double) distance / sqrt((double) norma * (double) normb); + +#ifdef _MSC_VER + /* /fp:fast may not propagate NaN */ + if (isnan(similarity)) + PG_RETURN_FLOAT8(NAN); +#endif + + /* Keep in range */ + if (similarity > 1) + similarity = 1; + else if (similarity < -1) + similarity = -1; + + PG_RETURN_FLOAT8(1 - similarity); +} + +/* + * Get the L1 distance between two half vectors + */ +PGDLLEXPORT PG_FUNCTION_INFO_V1(halfvec_l1_distance); +Datum +halfvec_l1_distance(PG_FUNCTION_ARGS) +{ + HalfVector *a = PG_GETARG_HALFVEC_P(0); + HalfVector *b = PG_GETARG_HALFVEC_P(1); + half *ax = a->x; + half *bx = b->x; + float distance = 0.0; + + CheckDims(a, b); + + /* Auto-vectorized */ + for (int i = 0; i < a->dim; i++) + distance += fabsf(HalfToFloat4(ax[i]) - HalfToFloat4(bx[i])); + + PG_RETURN_FLOAT8((double) distance); +} + +/* + * Get the L2 norm of a half vector + */ +PGDLLEXPORT PG_FUNCTION_INFO_V1(halfvec_norm); +Datum +halfvec_norm(PG_FUNCTION_ARGS) +{ + HalfVector *a = PG_GETARG_HALFVEC_P(0); + half *ax = a->x; + double norm = 0.0; + + /* Auto-vectorized */ + for (int i = 0; i < a->dim; i++) + norm += (double) HalfToFloat4(ax[i]) * (double) HalfToFloat4(ax[i]); + + PG_RETURN_FLOAT8(sqrt(norm)); +} diff --git a/src/halfvec.h b/src/halfvec.h new file mode 100644 index 0000000..21ef042 --- /dev/null +++ b/src/halfvec.h @@ -0,0 +1,38 @@ +#ifndef HALFVEC_H +#define HALFVEC_H + +#define __STDC_WANT_IEC_60559_TYPES_EXT__ + +#include + +#ifdef __FLT16_MAX__ +#define FLT16_SUPPORT +#endif + +#ifdef FLT16_SUPPORT +#define half _Float16 +#define HALF_MAX FLT16_MAX +#else +/* TODO #pragma message("")? */ +#define half uint16 +#define HALF_MAX 65504 +#endif + +#define HALFVEC_MAX_DIM 32000 + +#define HALFVEC_SIZE(_dim) (offsetof(HalfVector, x) + sizeof(half)*(_dim)) +#define DatumGetHalfVector(x) ((HalfVector *) PG_DETOAST_DATUM(x)) +#define PG_GETARG_HALFVEC_P(x) DatumGetHalfVector(PG_GETARG_DATUM(x)) +#define PG_RETURN_HALFVEC_P(x) PG_RETURN_POINTER(x) + +typedef struct HalfVector +{ + int32 vl_len_; /* varlena header (do not touch directly!) */ + int16 dim; /* number of dimensions */ + int16 unused; + half x[FLEXIBLE_ARRAY_MEMBER]; +} HalfVector; + +HalfVector *InitHalfVector(int dim); + +#endif diff --git a/src/hnsw.h b/src/hnsw.h index fed281e..7b4659e 100644 --- a/src/hnsw.h +++ b/src/hnsw.h @@ -55,6 +55,10 @@ #define HNSW_UPDATE_ENTRY_GREATER 1 #define HNSW_UPDATE_ENTRY_ALWAYS 2 +/* Data types */ +#define HNSW_TYPE_VECTOR 1 +#define HNSW_TYPE_HALFVEC 2 + /* Build phases */ /* PROGRESS_CREATEIDX_SUBPHASE_INITIALIZE is 1 */ #define PROGRESS_HNSW_PHASE_LOAD 2 @@ -242,6 +246,7 @@ typedef struct HnswBuildState Relation index; IndexInfo *indexInfo; ForkNumber forkNum; + int type; /* Settings */ int dimensions; @@ -262,7 +267,6 @@ typedef struct HnswBuildState HnswGraph *graph; double ml; int maxLevel; - Vector *normvec; /* Memory */ MemoryContext graphCtx; @@ -367,7 +371,8 @@ typedef struct HnswVacuumState int HnswGetM(Relation index); int HnswGetEfConstruction(Relation index); FmgrInfo *HnswOptionalProcInfo(Relation index, uint16 procnum); -bool HnswNormValue(FmgrInfo *procinfo, Oid collation, Datum *value, Vector * result); +int HnswGetType(Relation index); +bool HnswNormValue(FmgrInfo *procinfo, Oid collation, Datum *value, int type); Buffer HnswNewBuffer(Relation index, ForkNumber forkNum); void HnswInitPage(Buffer buf, Page page); void HnswInit(void); diff --git a/src/hnswbuild.c b/src/hnswbuild.c index a9c737c..a1a4e13 100644 --- a/src/hnswbuild.c +++ b/src/hnswbuild.c @@ -489,7 +489,7 @@ InsertTuple(Relation index, Datum *values, bool *isnull, ItemPointer heaptid, Hn /* Normalize if needed */ if (buildstate->normprocinfo != NULL) { - if (!HnswNormValue(buildstate->normprocinfo, buildstate->collation, &value, buildstate->normvec)) + if (!HnswNormValue(buildstate->normprocinfo, buildstate->collation, &value, buildstate->type)) return false; } @@ -671,21 +671,27 @@ HnswSharedMemoryAlloc(Size size, void *state) static void InitBuildState(HnswBuildState * buildstate, Relation heap, Relation index, IndexInfo *indexInfo, ForkNumber forkNum) { + int maxDimensions = HNSW_MAX_DIM; + buildstate->heap = heap; buildstate->index = index; buildstate->indexInfo = indexInfo; buildstate->forkNum = forkNum; + buildstate->type = HnswGetType(index); buildstate->m = HnswGetM(index); buildstate->efConstruction = HnswGetEfConstruction(index); buildstate->dimensions = TupleDescAttr(index->rd_att, 0)->atttypmod; + if (buildstate->type == HNSW_TYPE_HALFVEC) + maxDimensions *= 2; + /* Require column to have dimensions to be indexed */ if (buildstate->dimensions < 0) elog(ERROR, "column does not have dimensions"); - if (buildstate->dimensions > HNSW_MAX_DIM) - elog(ERROR, "column cannot have more than %d dimensions for hnsw index", HNSW_MAX_DIM); + if (buildstate->dimensions > maxDimensions) + elog(ERROR, "column cannot have more than %d dimensions for hnsw index", maxDimensions); if (buildstate->efConstruction < 2 * buildstate->m) elog(ERROR, "ef_construction must be greater than or equal to 2 * m"); @@ -703,9 +709,6 @@ InitBuildState(HnswBuildState * buildstate, Relation heap, Relation index, Index buildstate->ml = HnswGetMl(buildstate->m); buildstate->maxLevel = HnswGetMaxLevel(buildstate->m); - /* Reuse for each tuple */ - buildstate->normvec = InitVector(buildstate->dimensions); - buildstate->graphCtx = GenerationContextCreate(CurrentMemoryContext, "Hnsw build graph context", #if PG_VERSION_NUM >= 150000 @@ -729,7 +732,6 @@ InitBuildState(HnswBuildState * buildstate, Relation heap, Relation index, Index static void FreeBuildState(HnswBuildState * buildstate) { - pfree(buildstate->normvec); MemoryContextDelete(buildstate->graphCtx); MemoryContextDelete(buildstate->tmpCtx); } diff --git a/src/hnswinsert.c b/src/hnswinsert.c index c3c2885..0e09cfa 100644 --- a/src/hnswinsert.c +++ b/src/hnswinsert.c @@ -622,7 +622,7 @@ HnswInsertTuple(Relation index, Datum *values, bool *isnull, ItemPointer heap_ti normprocinfo = HnswOptionalProcInfo(index, HNSW_NORM_PROC); if (normprocinfo != NULL) { - if (!HnswNormValue(normprocinfo, collation, &value, NULL)) + if (!HnswNormValue(normprocinfo, collation, &value, HnswGetType(index))) return; } diff --git a/src/hnswscan.c b/src/hnswscan.c index eaf0519..59609a4 100644 --- a/src/hnswscan.c +++ b/src/hnswscan.c @@ -1,6 +1,7 @@ #include "postgres.h" #include "access/relscan.h" +#include "halfvec.h" #include "hnsw.h" #include "pgstat.h" #include "storage/bufmgr.h" @@ -73,7 +74,14 @@ GetScanValue(IndexScanDesc scan) Datum value; if (scan->orderByData->sk_flags & SK_ISNULL) - value = PointerGetDatum(InitVector(GetDimensions(scan->indexRelation))); + { + int dimensions = GetDimensions(scan->indexRelation); + + if (HnswGetType(scan->indexRelation) == HNSW_TYPE_HALFVEC) + value = PointerGetDatum(InitHalfVector(dimensions)); + else + value = PointerGetDatum(InitVector(dimensions)); + } else { value = scan->orderByData->sk_argument; @@ -84,7 +92,7 @@ GetScanValue(IndexScanDesc scan) /* Fine if normalization fails */ if (so->normprocinfo != NULL) - HnswNormValue(so->normprocinfo, so->collation, &value, NULL); + HnswNormValue(so->normprocinfo, so->collation, &value, HnswGetType(scan->indexRelation)); } return value; diff --git a/src/hnswutils.c b/src/hnswutils.c index 983fd11..74fd3c3 100644 --- a/src/hnswutils.c +++ b/src/hnswutils.c @@ -3,12 +3,15 @@ #include #include "access/generic_xlog.h" +#include "catalog/pg_type.h" +#include "halfvec.h" #include "hnsw.h" #include "lib/pairingheap.h" #include "storage/bufmgr.h" #include "utils/datum.h" #include "utils/memdebug.h" #include "utils/rel.h" +#include "utils/syscache.h" #include "vector.h" #if PG_VERSION_NUM >= 130000 @@ -149,6 +152,32 @@ HnswOptionalProcInfo(Relation index, uint16 procnum) return index_getprocinfo(index, 1, procnum); } +/* + * Get type + */ +int +HnswGetType(Relation index) +{ + Oid typeOid = TupleDescAttr(index->rd_att, 0)->atttypid; + HeapTuple tuple; + Form_pg_type type; + int result; + + tuple = SearchSysCache1(TYPEOID, ObjectIdGetDatum(typeOid)); + if (!HeapTupleIsValid(tuple)) + elog(ERROR, "cache lookup failed for type %u", typeOid); + + type = (Form_pg_type) GETSTRUCT(tuple); + if (strcmp(NameStr(type->typname), "halfvec") == 0) + result = HNSW_TYPE_HALFVEC; + else + result = HNSW_TYPE_VECTOR; + + ReleaseSysCache(tuple); + + return result; +} + /* * Divide by the norm * @@ -158,21 +187,34 @@ HnswOptionalProcInfo(Relation index, uint16 procnum) * if it's different than the original value */ bool -HnswNormValue(FmgrInfo *procinfo, Oid collation, Datum *value, Vector * result) +HnswNormValue(FmgrInfo *procinfo, Oid collation, Datum *value, int type) { double norm = DatumGetFloat8(FunctionCall1Coll(procinfo, collation, *value)); if (norm > 0) { - Vector *v = DatumGetVector(*value); + if (type == HNSW_TYPE_HALFVEC) + { + HalfVector *v = DatumGetHalfVector(*value); + HalfVector *result = InitHalfVector(v->dim); - if (result == NULL) - result = InitVector(v->dim); + for (int i = 0; i < v->dim; i++) + result->x[i] = v->x[i] / norm; - for (int i = 0; i < v->dim; i++) - result->x[i] = v->x[i] / norm; + *value = PointerGetDatum(result); + } + else if (type == HNSW_TYPE_VECTOR) + { + Vector *v = DatumGetVector(*value); + Vector *result = InitVector(v->dim); - *value = PointerGetDatum(result); + for (int i = 0; i < v->dim; i++) + result->x[i] = v->x[i] / norm; + + *value = PointerGetDatum(result); + } + else + elog(ERROR, "Unsupported type"); return true; } diff --git a/test/expected/copy.out b/test/expected/copy.out index 36d4620..79657cc 100644 --- a/test/expected/copy.out +++ b/test/expected/copy.out @@ -1,15 +1,15 @@ -CREATE TABLE t (val vector(3)); -INSERT INTO t (val) VALUES ('[0,0,0]'), ('[1,2,3]'), ('[1,1,1]'), (NULL); -CREATE TABLE t2 (val vector(3)); +CREATE TABLE t (val vector(3), val2 halfvec(3)); +INSERT INTO t (val, val2) VALUES ('[0,0,0]', '[0,0,0]'), ('[1,2,3]', '[1,2,3]'), ('[1,1,1]', '[1,1,1]'), (NULL, NULL); +CREATE TABLE t2 (val vector(3), val2 halfvec(3)); \copy t TO 'results/data.bin' WITH (FORMAT binary) \copy t2 FROM 'results/data.bin' WITH (FORMAT binary) SELECT * FROM t2 ORDER BY val; - val ---------- - [0,0,0] - [1,1,1] - [1,2,3] - + val | val2 +---------+--------- + [0,0,0] | [0,0,0] + [1,1,1] | [1,1,1] + [1,2,3] | [1,2,3] + | (4 rows) DROP TABLE t; diff --git a/test/expected/functions.out b/test/expected/functions.out index 85d1a2f..12f8f6d 100644 --- a/test/expected/functions.out +++ b/test/expected/functions.out @@ -104,105 +104,105 @@ SELECT vector_norm('[3e37,4e37]')::real; 5e+37 (1 row) -SELECT l2_distance('[0,0]', '[3,4]'); +SELECT l2_distance('[0,0]'::vector, '[3,4]'); l2_distance ------------- 5 (1 row) -SELECT l2_distance('[0,0]', '[0,1]'); +SELECT l2_distance('[0,0]'::vector, '[0,1]'); l2_distance ------------- 1 (1 row) -SELECT l2_distance('[1,2]', '[3]'); +SELECT l2_distance('[1,2]'::vector, '[3]'); ERROR: different vector dimensions 2 and 1 -SELECT l2_distance('[3e38]', '[-3e38]'); +SELECT l2_distance('[3e38]'::vector, '[-3e38]'); l2_distance ------------- Infinity (1 row) -SELECT inner_product('[1,2]', '[3,4]'); +SELECT inner_product('[1,2]'::vector, '[3,4]'); inner_product --------------- 11 (1 row) -SELECT inner_product('[1,2]', '[3]'); +SELECT inner_product('[1,2]'::vector, '[3]'); ERROR: different vector dimensions 2 and 1 -SELECT inner_product('[3e38]', '[3e38]'); +SELECT inner_product('[3e38]'::vector, '[3e38]'); inner_product --------------- Infinity (1 row) -SELECT cosine_distance('[1,2]', '[2,4]'); +SELECT cosine_distance('[1,2]'::vector, '[2,4]'); cosine_distance ----------------- 0 (1 row) -SELECT cosine_distance('[1,2]', '[0,0]'); +SELECT cosine_distance('[1,2]'::vector, '[0,0]'); cosine_distance ----------------- NaN (1 row) -SELECT cosine_distance('[1,1]', '[1,1]'); +SELECT cosine_distance('[1,1]'::vector, '[1,1]'); cosine_distance ----------------- 0 (1 row) -SELECT cosine_distance('[1,0]', '[0,2]'); +SELECT cosine_distance('[1,0]'::vector, '[0,2]'); cosine_distance ----------------- 1 (1 row) -SELECT cosine_distance('[1,1]', '[-1,-1]'); +SELECT cosine_distance('[1,1]'::vector, '[-1,-1]'); cosine_distance ----------------- 2 (1 row) -SELECT cosine_distance('[1,2]', '[3]'); +SELECT cosine_distance('[1,2]'::vector, '[3]'); ERROR: different vector dimensions 2 and 1 -SELECT cosine_distance('[1,1]', '[1.1,1.1]'); +SELECT cosine_distance('[1,1]'::vector, '[1.1,1.1]'); cosine_distance ----------------- 0 (1 row) -SELECT cosine_distance('[1,1]', '[-1.1,-1.1]'); +SELECT cosine_distance('[1,1]'::vector, '[-1.1,-1.1]'); cosine_distance ----------------- 2 (1 row) -SELECT cosine_distance('[3e38]', '[3e38]'); +SELECT cosine_distance('[3e38]'::vector, '[3e38]'); cosine_distance ----------------- NaN (1 row) -SELECT l1_distance('[0,0]', '[3,4]'); +SELECT l1_distance('[0,0]'::vector, '[3,4]'); l1_distance ------------- 7 (1 row) -SELECT l1_distance('[0,0]', '[0,1]'); +SELECT l1_distance('[0,0]'::vector, '[0,1]'); l1_distance ------------- 1 (1 row) -SELECT l1_distance('[1,2]', '[3]'); +SELECT l1_distance('[1,2]'::vector, '[3]'); ERROR: different vector dimensions 2 and 1 -SELECT l1_distance('[3e38]', '[-3e38]'); +SELECT l1_distance('[3e38]'::vector, '[-3e38]'); l1_distance ------------- Infinity diff --git a/test/expected/halfvec_functions.out b/test/expected/halfvec_functions.out new file mode 100644 index 0000000..b0bd832 --- /dev/null +++ b/test/expected/halfvec_functions.out @@ -0,0 +1,104 @@ +SELECT l2_distance('[0,0]'::halfvec, '[3,4]'); + l2_distance +------------- + 5 +(1 row) + +SELECT l2_distance('[0,0]'::halfvec, '[0,1]'); + l2_distance +------------- + 1 +(1 row) + +SELECT l2_distance('[1,2]'::halfvec, '[3]'); +ERROR: different halfvec dimensions 2 and 1 +SELECT '[0,0]'::halfvec <-> '[3,4]'; + ?column? +---------- + 5 +(1 row) + +SELECT inner_product('[1,2]'::halfvec, '[3,4]'); + inner_product +--------------- + 11 +(1 row) + +SELECT inner_product('[1,2]'::halfvec, '[3]'); +ERROR: different halfvec dimensions 2 and 1 +SELECT inner_product('[65504]'::halfvec, '[65504]'); + inner_product +--------------- + 4290774016 +(1 row) + +SELECT '[1,2]'::halfvec <#> '[3,4]'; + ?column? +---------- + -11 +(1 row) + +SELECT cosine_distance('[1,2]'::halfvec, '[2,4]'); + cosine_distance +----------------- + 0 +(1 row) + +SELECT cosine_distance('[1,2]'::halfvec, '[0,0]'); + cosine_distance +----------------- + NaN +(1 row) + +SELECT cosine_distance('[1,1]'::halfvec, '[1,1]'); + cosine_distance +----------------- + 0 +(1 row) + +SELECT cosine_distance('[1,0]'::halfvec, '[0,2]'); + cosine_distance +----------------- + 1 +(1 row) + +SELECT cosine_distance('[1,1]'::halfvec, '[-1,-1]'); + cosine_distance +----------------- + 2 +(1 row) + +SELECT cosine_distance('[1,2]'::halfvec, '[3]'); +ERROR: different halfvec dimensions 2 and 1 +SELECT cosine_distance('[1,1]'::halfvec, '[1.1,1.1]'); + cosine_distance +----------------- + 0 +(1 row) + +SELECT cosine_distance('[1,1]'::halfvec, '[-1.1,-1.1]'); + cosine_distance +----------------- + 2 +(1 row) + +SELECT '[1,2]'::halfvec <=> '[2,4]'; + ?column? +---------- + 0 +(1 row) + +SELECT l1_distance('[0,0]'::halfvec, '[3,4]'); + l1_distance +------------- + 7 +(1 row) + +SELECT l1_distance('[0,0]'::halfvec, '[0,1]'); + l1_distance +------------- + 1 +(1 row) + +SELECT l1_distance('[1,2]'::halfvec, '[3]'); +ERROR: different halfvec dimensions 2 and 1 diff --git a/test/expected/halfvec_input.out b/test/expected/halfvec_input.out new file mode 100644 index 0000000..3238005 --- /dev/null +++ b/test/expected/halfvec_input.out @@ -0,0 +1,145 @@ +SELECT '[1,2,3]'::halfvec; + halfvec +--------- + [1,2,3] +(1 row) + +SELECT '[-1,-2,-3]'::halfvec; + halfvec +------------ + [-1,-2,-3] +(1 row) + +SELECT '[1.,2.,3.]'::halfvec; + halfvec +--------- + [1,2,3] +(1 row) + +SELECT ' [ 1, 2 , 3 ] '::halfvec; + halfvec +--------- + [1,2,3] +(1 row) + +SELECT '[1.23456]'::halfvec; + halfvec +------------ + [1.234375] +(1 row) + +SELECT '[hello,1]'::halfvec; +ERROR: invalid input syntax for type halfvec: "[hello,1]" +LINE 1: SELECT '[hello,1]'::halfvec; + ^ +SELECT '[NaN,1]'::halfvec; +ERROR: NaN not allowed in halfvec +LINE 1: SELECT '[NaN,1]'::halfvec; + ^ +SELECT '[Infinity,1]'::halfvec; +ERROR: infinite value not allowed in halfvec +LINE 1: SELECT '[Infinity,1]'::halfvec; + ^ +SELECT '[-Infinity,1]'::halfvec; +ERROR: infinite value not allowed in halfvec +LINE 1: SELECT '[-Infinity,1]'::halfvec; + ^ +SELECT '[65519,-65519]'::halfvec; + halfvec +---------------- + [65504,-65504] +(1 row) + +SELECT '[65520,-65520]'::halfvec; +ERROR: value out of range: overflow +LINE 1: SELECT '[65520,-65520]'::halfvec; + ^ +SELECT '[1e-8,-1e-8]'::halfvec; +ERROR: value out of range: underflow +LINE 1: SELECT '[1e-8,-1e-8]'::halfvec; + ^ +SELECT '[4e38,1]'::halfvec; +ERROR: infinite value not allowed in halfvec +LINE 1: SELECT '[4e38,1]'::halfvec; + ^ +SELECT '[1,2,3'::halfvec; +ERROR: malformed halfvec literal: "[1,2,3" +LINE 1: SELECT '[1,2,3'::halfvec; + ^ +DETAIL: Unexpected end of input. +SELECT '[1,2,3]9'::halfvec; +ERROR: malformed halfvec literal: "[1,2,3]9" +LINE 1: SELECT '[1,2,3]9'::halfvec; + ^ +DETAIL: Junk after closing right brace. +SELECT '1,2,3'::halfvec; +ERROR: malformed halfvec literal: "1,2,3" +LINE 1: SELECT '1,2,3'::halfvec; + ^ +DETAIL: Vector contents must start with "[". +SELECT ''::halfvec; +ERROR: malformed halfvec literal: "" +LINE 1: SELECT ''::halfvec; + ^ +DETAIL: Vector contents must start with "[". +SELECT '['::halfvec; +ERROR: malformed halfvec literal: "[" +LINE 1: SELECT '['::halfvec; + ^ +DETAIL: Unexpected end of input. +SELECT '[,'::halfvec; +ERROR: malformed halfvec literal: "[," +LINE 1: SELECT '[,'::halfvec; + ^ +DETAIL: Unexpected end of input. +SELECT '[]'::halfvec; +ERROR: halfvec must have at least 1 dimension +LINE 1: SELECT '[]'::halfvec; + ^ +SELECT '[1,]'::halfvec; +ERROR: invalid input syntax for type halfvec: "[1,]" +LINE 1: SELECT '[1,]'::halfvec; + ^ +SELECT '[1a]'::halfvec; +ERROR: invalid input syntax for type halfvec: "[1a]" +LINE 1: SELECT '[1a]'::halfvec; + ^ +SELECT '[1,,3]'::halfvec; +ERROR: malformed halfvec literal: "[1,,3]" +LINE 1: SELECT '[1,,3]'::halfvec; + ^ +SELECT '[1, ,3]'::halfvec; +ERROR: invalid input syntax for type halfvec: "[1, ,3]" +LINE 1: SELECT '[1, ,3]'::halfvec; + ^ +SELECT '[1,2,3]'::halfvec(3); + halfvec +--------- + [1,2,3] +(1 row) + +SELECT '[1,2,3]'::halfvec(2); +ERROR: expected 2 dimensions, not 3 +SELECT '[1,2,3]'::halfvec(3, 2); +ERROR: invalid type modifier +LINE 1: SELECT '[1,2,3]'::halfvec(3, 2); + ^ +SELECT '[1,2,3]'::halfvec('a'); +ERROR: invalid input syntax for type integer: "a" +LINE 1: SELECT '[1,2,3]'::halfvec('a'); + ^ +SELECT '[1,2,3]'::halfvec(0); +ERROR: dimensions for type halfvec must be at least 1 +LINE 1: SELECT '[1,2,3]'::halfvec(0); + ^ +SELECT '[1,2,3]'::halfvec(16001); +ERROR: expected 16001 dimensions, not 3 +SELECT unnest('{"[1,2,3]", "[4,5,6]"}'::halfvec[]); + unnest +--------- + [1,2,3] + [4,5,6] +(2 rows) + +SELECT '{"[1,2,3]"}'::halfvec(2)[]; +ERROR: expected 2 dimensions, not 3 diff --git a/test/expected/hnsw_halfvec_cosine.out b/test/expected/hnsw_halfvec_cosine.out new file mode 100644 index 0000000..6ccca49 --- /dev/null +++ b/test/expected/hnsw_halfvec_cosine.out @@ -0,0 +1,26 @@ +SET enable_seqscan = off; +CREATE TABLE t (val halfvec(3)); +INSERT INTO t (val) VALUES ('[0,0,0]'), ('[1,2,3]'), ('[1,1,1]'), (NULL); +CREATE INDEX ON t USING hnsw (val halfvec_cosine_ops); +INSERT INTO t (val) VALUES ('[1,2,4]'); +SELECT * FROM t ORDER BY val <=> '[3,3,3]'; + val +--------- + [1,1,1] + [1,2,3] + [1,2,4] +(3 rows) + +SELECT COUNT(*) FROM (SELECT * FROM t ORDER BY val <=> '[0,0,0]') t2; + count +------- + 3 +(1 row) + +SELECT COUNT(*) FROM (SELECT * FROM t ORDER BY val <=> (SELECT NULL::halfvec)) t2; + count +------- + 3 +(1 row) + +DROP TABLE t; diff --git a/test/expected/hnsw_halfvec_ip.out b/test/expected/hnsw_halfvec_ip.out new file mode 100644 index 0000000..7c004c9 --- /dev/null +++ b/test/expected/hnsw_halfvec_ip.out @@ -0,0 +1,21 @@ +SET enable_seqscan = off; +CREATE TABLE t (val halfvec(3)); +INSERT INTO t (val) VALUES ('[0,0,0]'), ('[1,2,3]'), ('[1,1,1]'), (NULL); +CREATE INDEX ON t USING hnsw (val halfvec_ip_ops); +INSERT INTO t (val) VALUES ('[1,2,4]'); +SELECT * FROM t ORDER BY val <#> '[3,3,3]'; + val +--------- + [1,2,4] + [1,2,3] + [1,1,1] + [0,0,0] +(4 rows) + +SELECT COUNT(*) FROM (SELECT * FROM t ORDER BY val <#> (SELECT NULL::halfvec)) t2; + count +------- + 4 +(1 row) + +DROP TABLE t; diff --git a/test/expected/hnsw_halfvec_l2.out b/test/expected/hnsw_halfvec_l2.out new file mode 100644 index 0000000..a5ab825 --- /dev/null +++ b/test/expected/hnsw_halfvec_l2.out @@ -0,0 +1,36 @@ +SET enable_seqscan = off; +CREATE TABLE t (val halfvec(3)); +INSERT INTO t (val) VALUES ('[0,0,0]'), ('[1,2,3]'), ('[1,1,1]'), (NULL); +CREATE INDEX ON t USING hnsw (val halfvec_l2_ops); +INSERT INTO t (val) VALUES ('[1,2,4]'); +SELECT * FROM t ORDER BY val <-> '[3,3,3]'; + val +--------- + [1,2,3] + [1,2,4] + [1,1,1] + [0,0,0] +(4 rows) + +SELECT * FROM t ORDER BY val <-> (SELECT NULL::halfvec); + val +--------- + [0,0,0] + [1,1,1] + [1,2,3] + [1,2,4] +(4 rows) + +SELECT COUNT(*) FROM t; + count +------- + 5 +(1 row) + +TRUNCATE t; +SELECT * FROM t ORDER BY val <-> '[3,3,3]'; + val +----- +(0 rows) + +DROP TABLE t; diff --git a/test/sql/copy.sql b/test/sql/copy.sql index 2820090..ee93a50 100644 --- a/test/sql/copy.sql +++ b/test/sql/copy.sql @@ -1,7 +1,7 @@ -CREATE TABLE t (val vector(3)); -INSERT INTO t (val) VALUES ('[0,0,0]'), ('[1,2,3]'), ('[1,1,1]'), (NULL); +CREATE TABLE t (val vector(3), val2 halfvec(3)); +INSERT INTO t (val, val2) VALUES ('[0,0,0]', '[0,0,0]'), ('[1,2,3]', '[1,2,3]'), ('[1,1,1]', '[1,1,1]'), (NULL, NULL); -CREATE TABLE t2 (val vector(3)); +CREATE TABLE t2 (val vector(3), val2 halfvec(3)); \copy t TO 'results/data.bin' WITH (FORMAT binary) \copy t2 FROM 'results/data.bin' WITH (FORMAT binary) diff --git a/test/sql/functions.sql b/test/sql/functions.sql index 6235684..7e820d7 100644 --- a/test/sql/functions.sql +++ b/test/sql/functions.sql @@ -24,29 +24,29 @@ SELECT vector_norm('[3,4]'); SELECT vector_norm('[0,1]'); SELECT vector_norm('[3e37,4e37]')::real; -SELECT l2_distance('[0,0]', '[3,4]'); -SELECT l2_distance('[0,0]', '[0,1]'); -SELECT l2_distance('[1,2]', '[3]'); -SELECT l2_distance('[3e38]', '[-3e38]'); +SELECT l2_distance('[0,0]'::vector, '[3,4]'); +SELECT l2_distance('[0,0]'::vector, '[0,1]'); +SELECT l2_distance('[1,2]'::vector, '[3]'); +SELECT l2_distance('[3e38]'::vector, '[-3e38]'); -SELECT inner_product('[1,2]', '[3,4]'); -SELECT inner_product('[1,2]', '[3]'); -SELECT inner_product('[3e38]', '[3e38]'); +SELECT inner_product('[1,2]'::vector, '[3,4]'); +SELECT inner_product('[1,2]'::vector, '[3]'); +SELECT inner_product('[3e38]'::vector, '[3e38]'); -SELECT cosine_distance('[1,2]', '[2,4]'); -SELECT cosine_distance('[1,2]', '[0,0]'); -SELECT cosine_distance('[1,1]', '[1,1]'); -SELECT cosine_distance('[1,0]', '[0,2]'); -SELECT cosine_distance('[1,1]', '[-1,-1]'); -SELECT cosine_distance('[1,2]', '[3]'); -SELECT cosine_distance('[1,1]', '[1.1,1.1]'); -SELECT cosine_distance('[1,1]', '[-1.1,-1.1]'); -SELECT cosine_distance('[3e38]', '[3e38]'); +SELECT cosine_distance('[1,2]'::vector, '[2,4]'); +SELECT cosine_distance('[1,2]'::vector, '[0,0]'); +SELECT cosine_distance('[1,1]'::vector, '[1,1]'); +SELECT cosine_distance('[1,0]'::vector, '[0,2]'); +SELECT cosine_distance('[1,1]'::vector, '[-1,-1]'); +SELECT cosine_distance('[1,2]'::vector, '[3]'); +SELECT cosine_distance('[1,1]'::vector, '[1.1,1.1]'); +SELECT cosine_distance('[1,1]'::vector, '[-1.1,-1.1]'); +SELECT cosine_distance('[3e38]'::vector, '[3e38]'); -SELECT l1_distance('[0,0]', '[3,4]'); -SELECT l1_distance('[0,0]', '[0,1]'); -SELECT l1_distance('[1,2]', '[3]'); -SELECT l1_distance('[3e38]', '[-3e38]'); +SELECT l1_distance('[0,0]'::vector, '[3,4]'); +SELECT l1_distance('[0,0]'::vector, '[0,1]'); +SELECT l1_distance('[1,2]'::vector, '[3]'); +SELECT l1_distance('[3e38]'::vector, '[-3e38]'); SELECT avg(v) FROM unnest(ARRAY['[1,2,3]'::vector, '[3,5,7]']) v; SELECT avg(v) FROM unnest(ARRAY['[1,2,3]'::vector, '[3,5,7]', NULL]) v; diff --git a/test/sql/halfvec_functions.sql b/test/sql/halfvec_functions.sql new file mode 100644 index 0000000..17f28a5 --- /dev/null +++ b/test/sql/halfvec_functions.sql @@ -0,0 +1,23 @@ +SELECT l2_distance('[0,0]'::halfvec, '[3,4]'); +SELECT l2_distance('[0,0]'::halfvec, '[0,1]'); +SELECT l2_distance('[1,2]'::halfvec, '[3]'); +SELECT '[0,0]'::halfvec <-> '[3,4]'; + +SELECT inner_product('[1,2]'::halfvec, '[3,4]'); +SELECT inner_product('[1,2]'::halfvec, '[3]'); +SELECT inner_product('[65504]'::halfvec, '[65504]'); +SELECT '[1,2]'::halfvec <#> '[3,4]'; + +SELECT cosine_distance('[1,2]'::halfvec, '[2,4]'); +SELECT cosine_distance('[1,2]'::halfvec, '[0,0]'); +SELECT cosine_distance('[1,1]'::halfvec, '[1,1]'); +SELECT cosine_distance('[1,0]'::halfvec, '[0,2]'); +SELECT cosine_distance('[1,1]'::halfvec, '[-1,-1]'); +SELECT cosine_distance('[1,2]'::halfvec, '[3]'); +SELECT cosine_distance('[1,1]'::halfvec, '[1.1,1.1]'); +SELECT cosine_distance('[1,1]'::halfvec, '[-1.1,-1.1]'); +SELECT '[1,2]'::halfvec <=> '[2,4]'; + +SELECT l1_distance('[0,0]'::halfvec, '[3,4]'); +SELECT l1_distance('[0,0]'::halfvec, '[0,1]'); +SELECT l1_distance('[1,2]'::halfvec, '[3]'); diff --git a/test/sql/halfvec_input.sql b/test/sql/halfvec_input.sql new file mode 100644 index 0000000..1ae3abd --- /dev/null +++ b/test/sql/halfvec_input.sql @@ -0,0 +1,34 @@ +SELECT '[1,2,3]'::halfvec; +SELECT '[-1,-2,-3]'::halfvec; +SELECT '[1.,2.,3.]'::halfvec; +SELECT ' [ 1, 2 , 3 ] '::halfvec; +SELECT '[1.23456]'::halfvec; +SELECT '[hello,1]'::halfvec; +SELECT '[NaN,1]'::halfvec; +SELECT '[Infinity,1]'::halfvec; +SELECT '[-Infinity,1]'::halfvec; +SELECT '[65519,-65519]'::halfvec; +SELECT '[65520,-65520]'::halfvec; +SELECT '[1e-8,-1e-8]'::halfvec; +SELECT '[4e38,1]'::halfvec; +SELECT '[1,2,3'::halfvec; +SELECT '[1,2,3]9'::halfvec; +SELECT '1,2,3'::halfvec; +SELECT ''::halfvec; +SELECT '['::halfvec; +SELECT '[,'::halfvec; +SELECT '[]'::halfvec; +SELECT '[1,]'::halfvec; +SELECT '[1a]'::halfvec; +SELECT '[1,,3]'::halfvec; +SELECT '[1, ,3]'::halfvec; + +SELECT '[1,2,3]'::halfvec(3); +SELECT '[1,2,3]'::halfvec(2); +SELECT '[1,2,3]'::halfvec(3, 2); +SELECT '[1,2,3]'::halfvec('a'); +SELECT '[1,2,3]'::halfvec(0); +SELECT '[1,2,3]'::halfvec(16001); + +SELECT unnest('{"[1,2,3]", "[4,5,6]"}'::halfvec[]); +SELECT '{"[1,2,3]"}'::halfvec(2)[]; diff --git a/test/sql/hnsw_halfvec_cosine.sql b/test/sql/hnsw_halfvec_cosine.sql new file mode 100644 index 0000000..e5473b4 --- /dev/null +++ b/test/sql/hnsw_halfvec_cosine.sql @@ -0,0 +1,13 @@ +SET enable_seqscan = off; + +CREATE TABLE t (val halfvec(3)); +INSERT INTO t (val) VALUES ('[0,0,0]'), ('[1,2,3]'), ('[1,1,1]'), (NULL); +CREATE INDEX ON t USING hnsw (val halfvec_cosine_ops); + +INSERT INTO t (val) VALUES ('[1,2,4]'); + +SELECT * FROM t ORDER BY val <=> '[3,3,3]'; +SELECT COUNT(*) FROM (SELECT * FROM t ORDER BY val <=> '[0,0,0]') t2; +SELECT COUNT(*) FROM (SELECT * FROM t ORDER BY val <=> (SELECT NULL::halfvec)) t2; + +DROP TABLE t; diff --git a/test/sql/hnsw_halfvec_ip.sql b/test/sql/hnsw_halfvec_ip.sql new file mode 100644 index 0000000..bdf32f7 --- /dev/null +++ b/test/sql/hnsw_halfvec_ip.sql @@ -0,0 +1,12 @@ +SET enable_seqscan = off; + +CREATE TABLE t (val halfvec(3)); +INSERT INTO t (val) VALUES ('[0,0,0]'), ('[1,2,3]'), ('[1,1,1]'), (NULL); +CREATE INDEX ON t USING hnsw (val halfvec_ip_ops); + +INSERT INTO t (val) VALUES ('[1,2,4]'); + +SELECT * FROM t ORDER BY val <#> '[3,3,3]'; +SELECT COUNT(*) FROM (SELECT * FROM t ORDER BY val <#> (SELECT NULL::halfvec)) t2; + +DROP TABLE t; diff --git a/test/sql/hnsw_halfvec_l2.sql b/test/sql/hnsw_halfvec_l2.sql new file mode 100644 index 0000000..f754ad3 --- /dev/null +++ b/test/sql/hnsw_halfvec_l2.sql @@ -0,0 +1,16 @@ +SET enable_seqscan = off; + +CREATE TABLE t (val halfvec(3)); +INSERT INTO t (val) VALUES ('[0,0,0]'), ('[1,2,3]'), ('[1,1,1]'), (NULL); +CREATE INDEX ON t USING hnsw (val halfvec_l2_ops); + +INSERT INTO t (val) VALUES ('[1,2,4]'); + +SELECT * FROM t ORDER BY val <-> '[3,3,3]'; +SELECT * FROM t ORDER BY val <-> (SELECT NULL::halfvec); +SELECT COUNT(*) FROM t; + +TRUNCATE t; +SELECT * FROM t ORDER BY val <-> '[3,3,3]'; + +DROP TABLE t; diff --git a/test/t/020_hnsw_halfvec_build_recall.pl b/test/t/020_hnsw_halfvec_build_recall.pl new file mode 100644 index 0000000..35771cb --- /dev/null +++ b/test/t/020_hnsw_halfvec_build_recall.pl @@ -0,0 +1,128 @@ +use strict; +use warnings; +use PostgresNode; +use TestLib; +use Test::More; + +my $node; +my @queries = (); +my @expected; +my $limit = 20; + +sub test_recall +{ + my ($min, $operator) = @_; + my $correct = 0; + my $total = 0; + + my $explain = $node->safe_psql("postgres", qq( + SET enable_seqscan = off; + EXPLAIN ANALYZE SELECT i FROM tst ORDER BY v $operator '$queries[0]' LIMIT $limit; + )); + like($explain, qr/Index Scan/); + + for my $i (0 .. $#queries) + { + my $actual = $node->safe_psql("postgres", qq( + SET enable_seqscan = off; + SELECT i FROM tst ORDER BY v $operator '$queries[$i]' LIMIT $limit; + )); + my @actual_ids = split("\n", $actual); + my %actual_set = map { $_ => 1 } @actual_ids; + + my @expected_ids = split("\n", $expected[$i]); + + foreach (@expected_ids) + { + if (exists($actual_set{$_})) + { + $correct++; + } + $total++; + } + } + + cmp_ok($correct / $total, ">=", $min, $operator); +} + +# Initialize node +$node = get_new_node('node'); +$node->init; +$node->start; + +# Create table +$node->safe_psql("postgres", "CREATE EXTENSION vector;"); +$node->safe_psql("postgres", "CREATE TABLE tst (i int4, v halfvec(3));"); +$node->safe_psql("postgres", + "INSERT INTO tst SELECT i, ARRAY[random(), random(), random()] FROM generate_series(1, 10000) i;" +); + +# Generate queries +for (1 .. 20) +{ + my $r1 = rand(); + my $r2 = rand(); + my $r3 = rand(); + push(@queries, "[$r1,$r2,$r3]"); +} + +# Check each index type +my @operators = ("<->", "<#>"); #, "<=>"); +my @opclasses = ("halfvec_l2_ops", "halfvec_ip_ops"); #, "halfvec_cosine_ops"); + +for my $i (0 .. $#operators) +{ + my $operator = $operators[$i]; + my $opclass = $opclasses[$i]; + + # Get exact results + @expected = (); + foreach (@queries) + { + my $res = $node->safe_psql("postgres", "SELECT i FROM tst ORDER BY v $operator '$_' LIMIT $limit;"); + push(@expected, $res); + } + + # Build index serially + $node->safe_psql("postgres", qq( + SET max_parallel_maintenance_workers = 0; + CREATE INDEX idx ON tst USING hnsw (v $opclass); + )); + + # Test approximate results + my $min = $operator eq "<#>" ? 0.80 : 0.99; + test_recall($min, $operator); + + $node->safe_psql("postgres", "DROP INDEX idx;"); + + # Build index in parallel in memory + my ($ret, $stdout, $stderr) = $node->psql("postgres", qq( + SET client_min_messages = DEBUG; + SET min_parallel_table_scan_size = 1; + CREATE INDEX idx ON tst USING hnsw (v $opclass); + )); + is($ret, 0, $stderr); + like($stderr, qr/using \d+ parallel workers/); + + # Test approximate results + test_recall($min, $operator); + + $node->safe_psql("postgres", "DROP INDEX idx;"); + + # Build index in parallel on disk + # Set parallel_workers on table to use workers with low maintenance_work_mem + ($ret, $stdout, $stderr) = $node->psql("postgres", qq( + ALTER TABLE tst SET (parallel_workers = 2); + SET client_min_messages = DEBUG; + SET maintenance_work_mem = '4MB'; + CREATE INDEX idx ON tst USING hnsw (v $opclass); + ALTER TABLE tst RESET (parallel_workers); + )); + is($ret, 0, $stderr); + like($stderr, qr/using \d+ parallel workers/); + like($stderr, qr/hnsw graph no longer fits into maintenance_work_mem/); + + $node->safe_psql("postgres", "DROP INDEX idx;"); +} + +done_testing();