From 422667f6c6d02648282a9d141309f2dad62b12f7 Mon Sep 17 00:00:00 2001 From: Andrew Kane Date: Sun, 3 Dec 2023 13:01:47 -0800 Subject: [PATCH] Added half type --- .github/workflows/build.yml | 2 - Makefile | 4 +- Makefile.win | 4 +- sql/vector--0.5.1--0.6.0.sql | 77 +++++ sql/vector.sql | 87 +++++ src/half.c | 599 +++++++++++++++++++++++++++++++++++ src/half.h | 28 ++ test/expected/copy.out | 18 +- test/expected/functions.out | 40 +-- test/expected/half.out | 182 +++++++++++ test/sql/copy.sql | 6 +- test/sql/functions.sql | 40 +-- test/sql/half.sql | 42 +++ 13 files changed, 1071 insertions(+), 58 deletions(-) create mode 100644 sql/vector--0.5.1--0.6.0.sql create mode 100644 src/half.c create mode 100644 src/half.h create mode 100644 test/expected/half.out create mode 100644 test/sql/half.sql diff --git a/.github/workflows/build.yml b/.github/workflows/build.yml index 249400d..dc8890e 100644 --- a/.github/workflows/build.yml +++ b/.github/workflows/build.yml @@ -20,8 +20,6 @@ jobs: os: ubuntu-20.04 - postgres: 12 os: ubuntu-20.04 - - postgres: 11 - os: ubuntu-20.04 steps: - uses: actions/checkout@v4 - uses: ankane/setup-postgres@v1 diff --git a/Makefile b/Makefile index f6c1f20..197867a 100644 --- a/Makefile +++ b/Makefile @@ -3,8 +3,8 @@ EXTVERSION = 0.5.1 MODULE_big = vector DATA = $(wildcard sql/*--*.sql) -OBJS = src/hnsw.o src/hnswbuild.o src/hnswinsert.o src/hnswscan.o src/hnswutils.o src/hnswvacuum.o src/ivfbuild.o src/ivfflat.o src/ivfinsert.o src/ivfkmeans.o src/ivfscan.o src/ivfutils.o src/ivfvacuum.o src/vector.o -HEADERS = src/vector.h +OBJS = src/half.o src/hnsw.o src/hnswbuild.o src/hnswinsert.o src/hnswscan.o src/hnswutils.o src/hnswvacuum.o src/ivfbuild.o src/ivfflat.o src/ivfinsert.o src/ivfkmeans.o src/ivfscan.o src/ivfutils.o src/ivfvacuum.o src/vector.o +HEADERS = src/half.h src/vector.h TESTS = $(wildcard test/sql/*.sql) REGRESS = $(patsubst test/sql/%.sql,%,$(TESTS)) diff --git a/Makefile.win b/Makefile.win index f6d955a..089200a 100644 --- a/Makefile.win +++ b/Makefile.win @@ -1,8 +1,8 @@ EXTENSION = vector EXTVERSION = 0.5.1 -OBJS = src\hnsw.obj src\hnswbuild.obj src\hnswinsert.obj src\hnswscan.obj src\hnswutils.obj src\hnswvacuum.obj src\ivfbuild.obj src\ivfflat.obj src\ivfinsert.obj src\ivfkmeans.obj src\ivfscan.obj src\ivfutils.obj src\ivfvacuum.obj src\vector.obj -HEADERS = src\vector.h +OBJS = src\half.obj src\hnsw.obj src\hnswbuild.obj src\hnswinsert.obj src\hnswscan.obj src\hnswutils.obj src\hnswvacuum.obj src\ivfbuild.obj src\ivfflat.obj src\ivfinsert.obj src\ivfkmeans.obj src\ivfscan.obj src\ivfutils.obj src\ivfvacuum.obj src\vector.obj +HEADERS = src\half.h src\vector.h REGRESS = btree cast copy functions input ivfflat_cosine ivfflat_ip ivfflat_l2 ivfflat_options ivfflat_unlogged REGRESS_OPTS = --inputdir=test --load-extension=$(EXTENSION) diff --git a/sql/vector--0.5.1--0.6.0.sql b/sql/vector--0.5.1--0.6.0.sql new file mode 100644 index 0000000..61348fa --- /dev/null +++ b/sql/vector--0.5.1--0.6.0.sql @@ -0,0 +1,77 @@ +-- complain if script is sourced in psql, rather than via CREATE EXTENSION +\echo Use "ALTER EXTENSION vector UPDATE TO '0.6.0'" to load this file. \quit + +CREATE TYPE half; + +CREATE FUNCTION half_in(cstring, oid, integer) RETURNS half + AS 'MODULE_PATHNAME' LANGUAGE C IMMUTABLE STRICT PARALLEL SAFE; + +CREATE FUNCTION half_out(half) RETURNS cstring + AS 'MODULE_PATHNAME' LANGUAGE C IMMUTABLE STRICT PARALLEL SAFE; + +CREATE FUNCTION half_recv(internal, oid, integer) RETURNS half + AS 'MODULE_PATHNAME' LANGUAGE C IMMUTABLE STRICT PARALLEL SAFE; + +CREATE FUNCTION half_send(half) RETURNS bytea + AS 'MODULE_PATHNAME' LANGUAGE C IMMUTABLE STRICT PARALLEL SAFE; + +CREATE TYPE half ( + INPUT = half_in, + OUTPUT = half_out, + RECEIVE = half_recv, + SEND = half_send, + INTERNALLENGTH = 2, + PASSEDBYVALUE, + ALIGNMENT = int2 +); + +CREATE FUNCTION l2_distance(half[], half[]) RETURNS float8 + AS 'MODULE_PATHNAME', 'half_l2_distance' LANGUAGE C IMMUTABLE STRICT PARALLEL SAFE; + +CREATE FUNCTION inner_product(half[], half[]) RETURNS float8 + AS 'MODULE_PATHNAME', 'half_inner_product' LANGUAGE C IMMUTABLE STRICT PARALLEL SAFE; + +CREATE FUNCTION cosine_distance(half[], half[]) RETURNS float8 + AS 'MODULE_PATHNAME', 'half_cosine_distance' LANGUAGE C IMMUTABLE STRICT PARALLEL SAFE; + +CREATE FUNCTION l1_distance(half[], half[]) RETURNS float8 + AS 'MODULE_PATHNAME', 'half_l1_distance' LANGUAGE C IMMUTABLE STRICT PARALLEL SAFE; + +CREATE FUNCTION half_l2_squared_distance(half[], half[]) RETURNS float8 + AS 'MODULE_PATHNAME' LANGUAGE C IMMUTABLE STRICT PARALLEL SAFE; + +CREATE FUNCTION half_negative_inner_product(half[], half[]) RETURNS float8 + AS 'MODULE_PATHNAME' LANGUAGE C IMMUTABLE STRICT PARALLEL SAFE; + +CREATE FUNCTION float4_to_half(real, integer, boolean) RETURNS half + AS 'MODULE_PATHNAME' LANGUAGE C IMMUTABLE STRICT PARALLEL SAFE; + +CREATE FUNCTION integer_to_half(integer, integer, boolean) RETURNS half + AS 'MODULE_PATHNAME' LANGUAGE C IMMUTABLE STRICT PARALLEL SAFE; + +CREATE FUNCTION numeric_to_half(numeric, integer, boolean) RETURNS half + AS 'MODULE_PATHNAME' LANGUAGE C IMMUTABLE STRICT PARALLEL SAFE; + +CREATE CAST (real AS half) + WITH FUNCTION float4_to_half(real, integer, boolean) AS IMPLICIT; + +CREATE CAST (integer AS half) + WITH FUNCTION integer_to_half(integer, integer, boolean) AS IMPLICIT; + +CREATE CAST (numeric AS half) + WITH FUNCTION numeric_to_half(numeric, integer, boolean) AS IMPLICIT; + +CREATE OPERATOR <-> ( + LEFTARG = half[], RIGHTARG = half[], PROCEDURE = l2_distance, + COMMUTATOR = '<->' +); + +CREATE OPERATOR <#> ( + LEFTARG = half[], RIGHTARG = half[], PROCEDURE = half_negative_inner_product, + COMMUTATOR = '<#>' +); + +CREATE OPERATOR <=> ( + LEFTARG = half[], RIGHTARG = half[], PROCEDURE = cosine_distance, + COMMUTATOR = '<=>' +); diff --git a/sql/vector.sql b/sql/vector.sql index 137931f..c86d6d9 100644 --- a/sql/vector.sql +++ b/sql/vector.sql @@ -290,3 +290,90 @@ CREATE OPERATOR CLASS vector_cosine_ops OPERATOR 1 <=> (vector, vector) FOR ORDER BY float_ops, FUNCTION 1 vector_negative_inner_product(vector, vector), FUNCTION 2 vector_norm(vector); + +-- half type + +CREATE TYPE half; + +CREATE FUNCTION half_in(cstring, oid, integer) RETURNS half + AS 'MODULE_PATHNAME' LANGUAGE C IMMUTABLE STRICT PARALLEL SAFE; + +CREATE FUNCTION half_out(half) RETURNS cstring + AS 'MODULE_PATHNAME' LANGUAGE C IMMUTABLE STRICT PARALLEL SAFE; + +CREATE FUNCTION half_recv(internal, oid, integer) RETURNS half + AS 'MODULE_PATHNAME' LANGUAGE C IMMUTABLE STRICT PARALLEL SAFE; + +CREATE FUNCTION half_send(half) RETURNS bytea + AS 'MODULE_PATHNAME' LANGUAGE C IMMUTABLE STRICT PARALLEL SAFE; + +CREATE TYPE half ( + INPUT = half_in, + OUTPUT = half_out, + RECEIVE = half_recv, + SEND = half_send, + INTERNALLENGTH = 2, + PASSEDBYVALUE, + ALIGNMENT = int2 +); + +-- half functions + +CREATE FUNCTION l2_distance(half[], half[]) RETURNS float8 + AS 'MODULE_PATHNAME', 'half_l2_distance' LANGUAGE C IMMUTABLE STRICT PARALLEL SAFE; + +CREATE FUNCTION inner_product(half[], half[]) RETURNS float8 + AS 'MODULE_PATHNAME', 'half_inner_product' LANGUAGE C IMMUTABLE STRICT PARALLEL SAFE; + +CREATE FUNCTION cosine_distance(half[], half[]) RETURNS float8 + AS 'MODULE_PATHNAME', 'half_cosine_distance' LANGUAGE C IMMUTABLE STRICT PARALLEL SAFE; + +CREATE FUNCTION l1_distance(half[], half[]) RETURNS float8 + AS 'MODULE_PATHNAME', 'half_l1_distance' LANGUAGE C IMMUTABLE STRICT PARALLEL SAFE; + +-- half private functions + +CREATE FUNCTION half_l2_squared_distance(half[], half[]) RETURNS float8 + AS 'MODULE_PATHNAME' LANGUAGE C IMMUTABLE STRICT PARALLEL SAFE; + +CREATE FUNCTION half_negative_inner_product(half[], half[]) RETURNS float8 + AS 'MODULE_PATHNAME' LANGUAGE C IMMUTABLE STRICT PARALLEL SAFE; + +-- half cast functions + +CREATE FUNCTION float4_to_half(real, integer, boolean) RETURNS half + AS 'MODULE_PATHNAME' LANGUAGE C IMMUTABLE STRICT PARALLEL SAFE; + +CREATE FUNCTION integer_to_half(integer, integer, boolean) RETURNS half + AS 'MODULE_PATHNAME' LANGUAGE C IMMUTABLE STRICT PARALLEL SAFE; + +CREATE FUNCTION numeric_to_half(numeric, integer, boolean) RETURNS half + AS 'MODULE_PATHNAME' LANGUAGE C IMMUTABLE STRICT PARALLEL SAFE; + +-- half casts + +CREATE CAST (real AS half) + WITH FUNCTION float4_to_half(real, integer, boolean) AS IMPLICIT; + +CREATE CAST (integer AS half) + WITH FUNCTION integer_to_half(integer, integer, boolean) AS IMPLICIT; + +CREATE CAST (numeric AS half) + WITH FUNCTION numeric_to_half(numeric, integer, boolean) AS IMPLICIT; + +-- half operators + +CREATE OPERATOR <-> ( + LEFTARG = half[], RIGHTARG = half[], PROCEDURE = l2_distance, + COMMUTATOR = '<->' +); + +CREATE OPERATOR <#> ( + LEFTARG = half[], RIGHTARG = half[], PROCEDURE = half_negative_inner_product, + COMMUTATOR = '<#>' +); + +CREATE OPERATOR <=> ( + LEFTARG = half[], RIGHTARG = half[], PROCEDURE = cosine_distance, + COMMUTATOR = '<=>' +); diff --git a/src/half.c b/src/half.c new file mode 100644 index 0000000..96a3100 --- /dev/null +++ b/src/half.c @@ -0,0 +1,599 @@ +#include "postgres.h" + +#include + +#include "common/shortest_dec.h" +#include "fmgr.h" +#include "half.h" +#include "lib/stringinfo.h" +#include "libpq/pqformat.h" +#include "utils/array.h" +#include "utils/builtins.h" +#include "utils/float.h" +#include "utils/numeric.h" + +#if PG_VERSION_NUM < 120003 +static pg_noinline void +float_overflow_error(void) +{ + ereport(ERROR, + (errcode(ERRCODE_NUMERIC_VALUE_OUT_OF_RANGE), + errmsg("value out of range: overflow"))); +} + +static pg_noinline void +float_underflow_error(void) +{ + ereport(ERROR, + (errcode(ERRCODE_NUMERIC_VALUE_OUT_OF_RANGE), + errmsg("value out of range: underflow"))); +} +#endif + +/* + * Check if array is a vector + */ +static void +CheckArrayIsVector(ArrayType *array) +{ + if (ARR_NDIM(array) > 1) + ereport(ERROR, + (errcode(ERRCODE_DATA_EXCEPTION), + errmsg("array must be 1-D"))); + + if (ARR_HASNULL(array) && array_contains_nulls(array)) + ereport(ERROR, + (errcode(ERRCODE_NULL_VALUE_NOT_ALLOWED), + errmsg("array must not contain nulls"))); +} + +/* + * Check if dimensions are the same + */ +static int +CheckDims(ArrayType *a, ArrayType *b) +{ + int dima; + int dimb; + + CheckArrayIsVector(a); + CheckArrayIsVector(b); + + dima = ARR_DIMS(a)[0]; + dimb = ARR_DIMS(b)[0]; + + if (dima != dimb) + ereport(ERROR, + (errcode(ERRCODE_DATA_EXCEPTION), + errmsg("different dimensions %d and %d", dima, dimb))); + + return dima; +} + +/* + * Return the datum representation for a half + */ +static inline Datum +HalfGetDatum(half X) +{ + union + { + half value; + int16 retval; + } myunion; + + myunion.value = X; + return Int16GetDatum(myunion.retval); +} + +/* + * Return the half value of a datum + */ +static inline half +DatumGetHalf(Datum X) +{ + union + { + int16 value; + half retval; + } myunion; + + myunion.value = DatumGetInt16(X); + return myunion.retval; +} + +/* + * Append a half to a StringInfo buffer + */ +static half +pq_getmsghalf(StringInfo msg) +{ + union + { + half h; + uint16 i; + } swap; + + /* TODO Likely use float4 for clients */ + swap.i = pq_getmsgint(msg, 2); + return swap.h; +} + +/* + * Get a half from a message buffer + */ +static void +pq_sendhalf(StringInfo buf, half h) +{ + union + { + half h; + uint16 i; + } swap; + + /* TODO Likely use float4 for clients */ + swap.h = h; + pq_sendint16(buf, swap.i); +} + +/* + * Convert a half to a float4 + */ +static float +HalfToFloat4(half num) +{ +#ifdef FLT16_SUPPORT + return (float) num; +#else + /* TODO Improve performance */ + /* TODO Check endianness */ + uint16 bin = *((uint16 *) &num); + uint32 exponent = (bin & 0x7C00) >> 10; + uint32 mantissa = bin & 0x03FF; + + /* Sign */ + uint32 result = (bin & 0x8000) << 16; + + if (exponent == 31) + { + if (mantissa == 0) + { + /* Infinite */ + result |= 0x7F800000; + } + else + { + /* NaN */ + result |= 0x7FC00000; + result |= mantissa << 13; + } + } + else if (exponent == 0) + { + /* Subnormal */ + if (mantissa != 0) + { + exponent = -14; + + for (int i = 0; i < 10; i++) + { + mantissa <<= 1; + exponent -= 1; + + if ((mantissa >> 10) % 2 == 1) + { + mantissa &= 0x03ff; + break; + } + } + + result |= (exponent + 127) << 23; + result |= mantissa << 13; + } + } + else + { + /* Normal */ + result |= (exponent - 15 + 127) << 23; + result |= mantissa << 13; + } + + return *((float *) &result); +#endif +} + +/* + * Convert a float4 to a half + */ +static half +Float4ToHalfUnchecked(float num) +{ +#ifdef FLT16_SUPPORT + return (_Float16) num; +#else + /* TODO Improve performance */ + /* TODO Check endianness */ + uint32 bin = *((uint32 *) &num); + int exponent = (bin & 0x7F800000) >> 23; + int mantissa = bin & 0x007FFFFF; + + /* Sign */ + uint16 result = (bin & 0x80000000) >> 16; + + if (isinf(num)) + { + /* Infinite */ + result |= 0x7C00; + } + else if (isnan(num)) + { + /* NaN */ + result |= 0x7E00; + result |= mantissa >> 13; + } + else if (exponent > 98) + { + int m; + int gr; + int s; + + exponent -= 127; + s = mantissa & 0x00000FFF; + + /* Subnormal */ + if (exponent < -14) + { + int diff = -exponent - 14; + + mantissa >>= diff; + mantissa += 1 << (23 - diff); + s |= mantissa & 0x00000FFF; + } + + m = mantissa >> 13; + + /* Round */ + gr = (mantissa >> 12) % 4; + if (gr == 3 || (gr == 1 && s != 0)) + m += 1; + + if (m == 1024) + { + m = 0; + exponent += 1; + } + + if (exponent > 15) + { + /* Infinite */ + result |= 0x7C00; + } + else + { + if (exponent >= -14) + result |= (exponent + 15) << 10; + + result |= m; + } + } + + return *((half *) & result); +#endif +} + +/* + * Convert a float4 to a half + */ +static half +Float4ToHalf(float num) +{ + half result = Float4ToHalfUnchecked(num); + + /* TODO Perform checks without HalfToFloat4 */ + if (unlikely(isinf(HalfToFloat4(result))) && !isinf(num)) + float_overflow_error(); + if (unlikely(HalfToFloat4(result) == 0.0f) && num != 0.0) + float_underflow_error(); + + return result; +} + +/* + * Convert textual representation to internal representation + */ +PGDLLEXPORT PG_FUNCTION_INFO_V1(half_in); +Datum +half_in(PG_FUNCTION_ARGS) +{ + char *num = PG_GETARG_CSTRING(0); + char *orig_num; + float val; + char *endptr; + + orig_num = num; + + /* Skip leading whitespace */ + while (*num != '\0' && isspace((unsigned char) *num)) + num++; + + /* + * Check for an empty-string input to begin with, to avoid the vagaries of + * strtof() on different platforms. + */ + if (*num == '\0') + ereport(ERROR, + (errcode(ERRCODE_INVALID_TEXT_REPRESENTATION), + errmsg("invalid input syntax for type %s: \"%s\"", + "half", orig_num))); + + val = strtof(num, &endptr); + + if (val < -HALF_MAX || val > HALF_MAX) + ereport(ERROR, + (errcode(ERRCODE_NUMERIC_VALUE_OUT_OF_RANGE), + errmsg("\"%s\" is out of range for type %s", + orig_num, "half"))); + + /* Skip trailing whitespace */ + while (*endptr != '\0' && isspace((unsigned char) *endptr)) + endptr++; + + /* If there is any junk left at the end of the string, bail out */ + if (*endptr != '\0') + ereport(ERROR, + (errcode(ERRCODE_INVALID_TEXT_REPRESENTATION), + errmsg("invalid input syntax for type %s: \"%s\"", + "half", orig_num))); + + PG_RETURN_HALF(Float4ToHalf(val)); +} + +/* + * Convert internal representation to textual representation + */ +PGDLLEXPORT PG_FUNCTION_INFO_V1(half_out); +Datum +half_out(PG_FUNCTION_ARGS) +{ + float num = HalfToFloat4(PG_GETARG_HALF(0)); + char *ascii = (char *) palloc(32); + int ndig = FLT_DIG + extra_float_digits; + + if (extra_float_digits > 0) + { + float_to_shortest_decimal_buf(num, ascii); + PG_RETURN_CSTRING(ascii); + } + + (void) pg_strfromd(ascii, 32, ndig, num); + PG_RETURN_CSTRING(ascii); +} + +/* + * Convert external binary representation to internal representation + */ +PGDLLEXPORT PG_FUNCTION_INFO_V1(half_recv); +Datum +half_recv(PG_FUNCTION_ARGS) +{ + StringInfo buf = (StringInfo) PG_GETARG_POINTER(0); + + PG_RETURN_HALF(pq_getmsghalf(buf)); +} + +/* + * Convert internal representation to the external binary representation + */ +PGDLLEXPORT PG_FUNCTION_INFO_V1(half_send); +Datum +half_send(PG_FUNCTION_ARGS) +{ + half arg1 = PG_GETARG_HALF(0); + StringInfoData buf; + + pq_begintypsend(&buf); + pq_sendhalf(&buf, arg1); + PG_RETURN_BYTEA_P(pq_endtypsend(&buf)); +} + +/* + * Convert integer to half + */ +PGDLLEXPORT PG_FUNCTION_INFO_V1(integer_to_half); +Datum +integer_to_half(PG_FUNCTION_ARGS) +{ + int32 i = PG_GETARG_INT32(0); + + /* TODO Figure out correct error */ + float f = (float) i; + half h = Float4ToHalf(f); + + PG_RETURN_HALF(h); +} + +/* + * Convert numeric to half + */ +PGDLLEXPORT PG_FUNCTION_INFO_V1(numeric_to_half); +Datum +numeric_to_half(PG_FUNCTION_ARGS) +{ + Numeric num = PG_GETARG_NUMERIC(0); + float f = DatumGetFloat4(DirectFunctionCall1(numeric_float4, NumericGetDatum(num))); + half h = Float4ToHalf(f); + + PG_RETURN_HALF(h); +} + +/* + * Convert float4 to half + */ +PGDLLEXPORT PG_FUNCTION_INFO_V1(float4_to_half); +Datum +float4_to_half(PG_FUNCTION_ARGS) +{ + float f = PG_GETARG_FLOAT4(0); + half h = Float4ToHalf(f); + + PG_RETURN_HALF(h); +} + +/* + * Get the L2 distance between half arrays + */ +PGDLLEXPORT PG_FUNCTION_INFO_V1(half_l2_distance); +Datum +half_l2_distance(PG_FUNCTION_ARGS) +{ + ArrayType *a = PG_GETARG_ARRAYTYPE_P(0); + ArrayType *b = PG_GETARG_ARRAYTYPE_P(1); + half *ax = (half *) ARR_DATA_PTR(a); + half *bx = (half *) ARR_DATA_PTR(b); + float distance = 0.0; + int dim = CheckDims(a, b); + + /* Auto-vectorized */ + for (int i = 0; i < dim; i++) + { + float diff = HalfToFloat4(ax[i]) - HalfToFloat4(bx[i]); + + distance += diff * diff; + } + + PG_RETURN_FLOAT8(sqrt((double) distance)); +} + +/* + * Get the L2 squared distance between half arrays + */ +PGDLLEXPORT PG_FUNCTION_INFO_V1(half_l2_squared_distance); +Datum +half_l2_squared_distance(PG_FUNCTION_ARGS) +{ + ArrayType *a = PG_GETARG_ARRAYTYPE_P(0); + ArrayType *b = PG_GETARG_ARRAYTYPE_P(1); + half *ax = (half *) ARR_DATA_PTR(a); + half *bx = (half *) ARR_DATA_PTR(b); + float distance = 0.0; + int dim = CheckDims(a, b); + + /* Auto-vectorized */ + for (int i = 0; i < dim; i++) + { + float diff = HalfToFloat4(ax[i]) - HalfToFloat4(bx[i]); + + distance += diff * diff; + } + + PG_RETURN_FLOAT8((double) distance); +} + +/* + * Get the inner product of two half arrays + */ +PGDLLEXPORT PG_FUNCTION_INFO_V1(half_inner_product); +Datum +half_inner_product(PG_FUNCTION_ARGS) +{ + ArrayType *a = PG_GETARG_ARRAYTYPE_P(0); + ArrayType *b = PG_GETARG_ARRAYTYPE_P(1); + half *ax = (half *) ARR_DATA_PTR(a); + half *bx = (half *) ARR_DATA_PTR(b); + float distance = 0.0; + int dim = CheckDims(a, b); + + /* Auto-vectorized */ + for (int i = 0; i < dim; i++) + distance += HalfToFloat4(ax[i]) * HalfToFloat4(bx[i]); + + PG_RETURN_FLOAT8((double) distance); +} + +/* + * Get the negative inner product of two half arrays + */ +PGDLLEXPORT PG_FUNCTION_INFO_V1(half_negative_inner_product); +Datum +half_negative_inner_product(PG_FUNCTION_ARGS) +{ + ArrayType *a = PG_GETARG_ARRAYTYPE_P(0); + ArrayType *b = PG_GETARG_ARRAYTYPE_P(1); + half *ax = (half *) ARR_DATA_PTR(a); + half *bx = (half *) ARR_DATA_PTR(b); + float distance = 0.0; + int dim = CheckDims(a, b); + + /* Auto-vectorized */ + for (int i = 0; i < dim; i++) + distance += HalfToFloat4(ax[i]) * HalfToFloat4(bx[i]); + + PG_RETURN_FLOAT8((double) distance * -1); +} + +/* + * Get the cosine distance between two half arrays + */ +PGDLLEXPORT PG_FUNCTION_INFO_V1(half_cosine_distance); +Datum +half_cosine_distance(PG_FUNCTION_ARGS) +{ + ArrayType *a = PG_GETARG_ARRAYTYPE_P(0); + ArrayType *b = PG_GETARG_ARRAYTYPE_P(1); + half *ax = (half *) ARR_DATA_PTR(a); + half *bx = (half *) ARR_DATA_PTR(b); + float distance = 0.0; + float norma = 0.0; + float normb = 0.0; + double similarity; + int dim = CheckDims(a, b); + + /* Auto-vectorized */ + for (int i = 0; i < dim; i++) + { + float axi = HalfToFloat4(ax[i]); + float bxi = HalfToFloat4(bx[i]); + + distance += axi * bxi; + norma += axi * axi; + normb += bxi * bxi; + } + + /* Use sqrt(a * b) over sqrt(a) * sqrt(b) */ + similarity = (double) distance / sqrt((double) norma * (double) normb); + +#ifdef _MSC_VER + /* /fp:fast may not propagate NaN */ + if (isnan(similarity)) + PG_RETURN_FLOAT8(NAN); +#endif + + /* Keep in range */ + if (similarity > 1) + similarity = 1; + else if (similarity < -1) + similarity = -1; + + PG_RETURN_FLOAT8(1 - similarity); +} + +/* + * Get the L1 distance between two half arrays + */ +PGDLLEXPORT PG_FUNCTION_INFO_V1(half_l1_distance); +Datum +half_l1_distance(PG_FUNCTION_ARGS) +{ + ArrayType *a = PG_GETARG_ARRAYTYPE_P(0); + ArrayType *b = PG_GETARG_ARRAYTYPE_P(1); + half *ax = (half *) ARR_DATA_PTR(a); + half *bx = (half *) ARR_DATA_PTR(b); + float distance = 0.0; + int dim = CheckDims(a, b); + + /* Auto-vectorized */ + for (int i = 0; i < dim; i++) + distance += fabsf(HalfToFloat4(ax[i]) - HalfToFloat4(bx[i])); + + PG_RETURN_FLOAT8((double) distance); +} diff --git a/src/half.h b/src/half.h new file mode 100644 index 0000000..6923c67 --- /dev/null +++ b/src/half.h @@ -0,0 +1,28 @@ +#ifndef HALF_H +#define HALF_H + +#define __STDC_WANT_IEC_60559_TYPES_EXT__ + +#include + +/* _Float16 and __fp16 are not supported on x86_64 with GCC 11 */ +#if defined(__is_identifier) +#if __is_identifier(_Float16) +#define FLT16_SUPPORT +#endif +#elif defined(FLT16_MAX) +#define FLT16_SUPPORT +#endif + +#ifdef FLT16_SUPPORT +#define half _Float16 +#define HALF_MAX FLT16_MAX +#else +#define half uint16 +#define HALF_MAX 65504 +#endif + +#define PG_GETARG_HALF(n) DatumGetHalf(PG_GETARG_DATUM(n)) +#define PG_RETURN_HALF(x) return HalfGetDatum(x) + +#endif diff --git a/test/expected/copy.out b/test/expected/copy.out index 36d4620..827cbfe 100644 --- a/test/expected/copy.out +++ b/test/expected/copy.out @@ -1,15 +1,15 @@ -CREATE TABLE t (val vector(3)); -INSERT INTO t (val) VALUES ('[0,0,0]'), ('[1,2,3]'), ('[1,1,1]'), (NULL); -CREATE TABLE t2 (val vector(3)); +CREATE TABLE t (val vector(3), val2 half[]); +INSERT INTO t (val, val2) VALUES ('[0,0,0]', '{0,0,0}'), ('[1,2,3]', '{1,2,3}'), ('[1,1,1]', '{1,1,1}'), (NULL, NULL); +CREATE TABLE t2 (val vector(3), val2 half[]); \copy t TO 'results/data.bin' WITH (FORMAT binary) \copy t2 FROM 'results/data.bin' WITH (FORMAT binary) SELECT * FROM t2 ORDER BY val; - val ---------- - [0,0,0] - [1,1,1] - [1,2,3] - + val | val2 +---------+--------- + [0,0,0] | {0,0,0} + [1,1,1] | {1,1,1} + [1,2,3] | {1,2,3} + | (4 rows) DROP TABLE t; diff --git a/test/expected/functions.out b/test/expected/functions.out index 2840688..1bc5616 100644 --- a/test/expected/functions.out +++ b/test/expected/functions.out @@ -54,105 +54,105 @@ SELECT vector_norm('[3e37,4e37]')::real; 5e+37 (1 row) -SELECT l2_distance('[0,0]', '[3,4]'); +SELECT l2_distance('[0,0]'::vector, '[3,4]'); l2_distance ------------- 5 (1 row) -SELECT l2_distance('[0,0]', '[0,1]'); +SELECT l2_distance('[0,0]'::vector, '[0,1]'); l2_distance ------------- 1 (1 row) -SELECT l2_distance('[1,2]', '[3]'); +SELECT l2_distance('[1,2]'::vector, '[3]'); ERROR: different vector dimensions 2 and 1 -SELECT l2_distance('[3e38]', '[-3e38]'); +SELECT l2_distance('[3e38]'::vector, '[-3e38]'); l2_distance ------------- Infinity (1 row) -SELECT inner_product('[1,2]', '[3,4]'); +SELECT inner_product('[1,2]'::vector, '[3,4]'); inner_product --------------- 11 (1 row) -SELECT inner_product('[1,2]', '[3]'); +SELECT inner_product('[1,2]'::vector, '[3]'); ERROR: different vector dimensions 2 and 1 -SELECT inner_product('[3e38]', '[3e38]'); +SELECT inner_product('[3e38]'::vector, '[3e38]'); inner_product --------------- Infinity (1 row) -SELECT cosine_distance('[1,2]', '[2,4]'); +SELECT cosine_distance('[1,2]'::vector, '[2,4]'); cosine_distance ----------------- 0 (1 row) -SELECT cosine_distance('[1,2]', '[0,0]'); +SELECT cosine_distance('[1,2]'::vector, '[0,0]'); cosine_distance ----------------- NaN (1 row) -SELECT cosine_distance('[1,1]', '[1,1]'); +SELECT cosine_distance('[1,1]'::vector, '[1,1]'); cosine_distance ----------------- 0 (1 row) -SELECT cosine_distance('[1,0]', '[0,2]'); +SELECT cosine_distance('[1,0]'::vector, '[0,2]'); cosine_distance ----------------- 1 (1 row) -SELECT cosine_distance('[1,1]', '[-1,-1]'); +SELECT cosine_distance('[1,1]'::vector, '[-1,-1]'); cosine_distance ----------------- 2 (1 row) -SELECT cosine_distance('[1,2]', '[3]'); +SELECT cosine_distance('[1,2]'::vector, '[3]'); ERROR: different vector dimensions 2 and 1 -SELECT cosine_distance('[1,1]', '[1.1,1.1]'); +SELECT cosine_distance('[1,1]'::vector, '[1.1,1.1]'); cosine_distance ----------------- 0 (1 row) -SELECT cosine_distance('[1,1]', '[-1.1,-1.1]'); +SELECT cosine_distance('[1,1]'::vector, '[-1.1,-1.1]'); cosine_distance ----------------- 2 (1 row) -SELECT cosine_distance('[3e38]', '[3e38]'); +SELECT cosine_distance('[3e38]'::vector, '[3e38]'); cosine_distance ----------------- NaN (1 row) -SELECT l1_distance('[0,0]', '[3,4]'); +SELECT l1_distance('[0,0]'::vector, '[3,4]'); l1_distance ------------- 7 (1 row) -SELECT l1_distance('[0,0]', '[0,1]'); +SELECT l1_distance('[0,0]'::vector, '[0,1]'); l1_distance ------------- 1 (1 row) -SELECT l1_distance('[1,2]', '[3]'); +SELECT l1_distance('[1,2]'::vector, '[3]'); ERROR: different vector dimensions 2 and 1 -SELECT l1_distance('[3e38]', '[-3e38]'); +SELECT l1_distance('[3e38]'::vector, '[-3e38]'); l1_distance ------------- Infinity diff --git a/test/expected/half.out b/test/expected/half.out new file mode 100644 index 0000000..5e54327 --- /dev/null +++ b/test/expected/half.out @@ -0,0 +1,182 @@ +SELECT '1.5'::half; + half +------ + 1.5 +(1 row) + +SELECT '65504'::half; + half +------- + 65504 +(1 row) + +SELECT '65505'::half; +ERROR: "65505" is out of range for type half +LINE 1: SELECT '65505'::half; + ^ +SELECT '-65504'::half; + half +-------- + -65504 +(1 row) + +SELECT '-65505'::half; +ERROR: "-65505" is out of range for type half +LINE 1: SELECT '-65505'::half; + ^ +SELECT ''::half; +ERROR: invalid input syntax for type half: "" +LINE 1: SELECT ''::half; + ^ +SELECT ' '::half; +ERROR: invalid input syntax for type half: " " +LINE 1: SELECT ' '::half; + ^ +SELECT '-'::half; +ERROR: invalid input syntax for type half: "-" +LINE 1: SELECT '-'::half; + ^ +SELECT ' 1.5'::half; + half +------ + 1.5 +(1 row) + +SELECT '1.5 '::half; + half +------ + 1.5 +(1 row) + +SELECT '1.5a'::half; +ERROR: invalid input syntax for type half: "1.5a" +LINE 1: SELECT '1.5a'::half; + ^ +SELECT '{1,2,3}'::half[]; + half +--------- + {1,2,3} +(1 row) + +SELECT '65505'::integer::half; + half +------- + 65504 +(1 row) + +SELECT 'NaN'::real::half; + half +------ + NaN +(1 row) + +SELECT 'Infinity'::real::half; + half +---------- + Infinity +(1 row) + +SELECT l2_distance('{0,0}'::half[], '{3,4}'::half[]); + l2_distance +------------- + 5 +(1 row) + +SELECT l2_distance('{0,0}'::half[], '{0,1}'::half[]); + l2_distance +------------- + 1 +(1 row) + +SELECT l2_distance('{1,2}'::half[], '{3}'::half[]); +ERROR: different dimensions 2 and 1 +SELECT '{0,0}'::half[] <-> '{3,4}'::half[]; + ?column? +---------- + 5 +(1 row) + +SELECT inner_product('{1,2}'::half[], '{3,4}'::half[]); + inner_product +--------------- + 11 +(1 row) + +SELECT inner_product('{1,2}'::half[], '{3}'::half[]); +ERROR: different dimensions 2 and 1 +SELECT inner_product('{65504}'::half[], '{65504}'::half[]); + inner_product +--------------- + 4290774016 +(1 row) + +SELECT '{1,2}'::half[] <#> '{3,4}'::half[]; + ?column? +---------- + -11 +(1 row) + +SELECT cosine_distance('{1,2}'::half[], '{2,4}'::half[]); + cosine_distance +----------------- + 0 +(1 row) + +SELECT cosine_distance('{1,2}'::half[], '{0,0}'::half[]); + cosine_distance +----------------- + NaN +(1 row) + +SELECT cosine_distance('{1,1}'::half[], '{1,1}'::half[]); + cosine_distance +----------------- + 0 +(1 row) + +SELECT cosine_distance('{1,0}'::half[], '{0,2}'::half[]); + cosine_distance +----------------- + 1 +(1 row) + +SELECT cosine_distance('{1,1}'::half[], '{-1,-1}'::half[]); + cosine_distance +----------------- + 2 +(1 row) + +SELECT cosine_distance('{1,2}'::half[], '{3}'::half[]); +ERROR: different dimensions 2 and 1 +SELECT cosine_distance('{1,1}'::half[], '{1.1,1.1}'::half[]); + cosine_distance +----------------- + 0 +(1 row) + +SELECT cosine_distance('{1,1}'::half[], '{-1.1,-1.1}'::half[]); + cosine_distance +----------------- + 2 +(1 row) + +SELECT '{1,2}'::half[] <=> '{2,4}'::half[]; + ?column? +---------- + 0 +(1 row) + +SELECT l1_distance('{0,0}'::half[], '{3,4}'); + l1_distance +------------- + 7 +(1 row) + +SELECT l1_distance('{0,0}'::half[], '{0,1}'); + l1_distance +------------- + 1 +(1 row) + +SELECT l1_distance('{1,2}'::half[], '{3}'); +ERROR: different dimensions 2 and 1 diff --git a/test/sql/copy.sql b/test/sql/copy.sql index 2820090..777fcfc 100644 --- a/test/sql/copy.sql +++ b/test/sql/copy.sql @@ -1,7 +1,7 @@ -CREATE TABLE t (val vector(3)); -INSERT INTO t (val) VALUES ('[0,0,0]'), ('[1,2,3]'), ('[1,1,1]'), (NULL); +CREATE TABLE t (val vector(3), val2 half[]); +INSERT INTO t (val, val2) VALUES ('[0,0,0]', '{0,0,0}'), ('[1,2,3]', '{1,2,3}'), ('[1,1,1]', '{1,1,1}'), (NULL, NULL); -CREATE TABLE t2 (val vector(3)); +CREATE TABLE t2 (val vector(3), val2 half[]); \copy t TO 'results/data.bin' WITH (FORMAT binary) \copy t2 FROM 'results/data.bin' WITH (FORMAT binary) diff --git a/test/sql/functions.sql b/test/sql/functions.sql index 914df36..3311adc 100644 --- a/test/sql/functions.sql +++ b/test/sql/functions.sql @@ -13,29 +13,29 @@ SELECT vector_norm('[3,4]'); SELECT vector_norm('[0,1]'); SELECT vector_norm('[3e37,4e37]')::real; -SELECT l2_distance('[0,0]', '[3,4]'); -SELECT l2_distance('[0,0]', '[0,1]'); -SELECT l2_distance('[1,2]', '[3]'); -SELECT l2_distance('[3e38]', '[-3e38]'); +SELECT l2_distance('[0,0]'::vector, '[3,4]'); +SELECT l2_distance('[0,0]'::vector, '[0,1]'); +SELECT l2_distance('[1,2]'::vector, '[3]'); +SELECT l2_distance('[3e38]'::vector, '[-3e38]'); -SELECT inner_product('[1,2]', '[3,4]'); -SELECT inner_product('[1,2]', '[3]'); -SELECT inner_product('[3e38]', '[3e38]'); +SELECT inner_product('[1,2]'::vector, '[3,4]'); +SELECT inner_product('[1,2]'::vector, '[3]'); +SELECT inner_product('[3e38]'::vector, '[3e38]'); -SELECT cosine_distance('[1,2]', '[2,4]'); -SELECT cosine_distance('[1,2]', '[0,0]'); -SELECT cosine_distance('[1,1]', '[1,1]'); -SELECT cosine_distance('[1,0]', '[0,2]'); -SELECT cosine_distance('[1,1]', '[-1,-1]'); -SELECT cosine_distance('[1,2]', '[3]'); -SELECT cosine_distance('[1,1]', '[1.1,1.1]'); -SELECT cosine_distance('[1,1]', '[-1.1,-1.1]'); -SELECT cosine_distance('[3e38]', '[3e38]'); +SELECT cosine_distance('[1,2]'::vector, '[2,4]'); +SELECT cosine_distance('[1,2]'::vector, '[0,0]'); +SELECT cosine_distance('[1,1]'::vector, '[1,1]'); +SELECT cosine_distance('[1,0]'::vector, '[0,2]'); +SELECT cosine_distance('[1,1]'::vector, '[-1,-1]'); +SELECT cosine_distance('[1,2]'::vector, '[3]'); +SELECT cosine_distance('[1,1]'::vector, '[1.1,1.1]'); +SELECT cosine_distance('[1,1]'::vector, '[-1.1,-1.1]'); +SELECT cosine_distance('[3e38]'::vector, '[3e38]'); -SELECT l1_distance('[0,0]', '[3,4]'); -SELECT l1_distance('[0,0]', '[0,1]'); -SELECT l1_distance('[1,2]', '[3]'); -SELECT l1_distance('[3e38]', '[-3e38]'); +SELECT l1_distance('[0,0]'::vector, '[3,4]'); +SELECT l1_distance('[0,0]'::vector, '[0,1]'); +SELECT l1_distance('[1,2]'::vector, '[3]'); +SELECT l1_distance('[3e38]'::vector, '[-3e38]'); SELECT avg(v) FROM unnest(ARRAY['[1,2,3]'::vector, '[3,5,7]']) v; SELECT avg(v) FROM unnest(ARRAY['[1,2,3]'::vector, '[3,5,7]', NULL]) v; diff --git a/test/sql/half.sql b/test/sql/half.sql new file mode 100644 index 0000000..3db34ba --- /dev/null +++ b/test/sql/half.sql @@ -0,0 +1,42 @@ +SELECT '1.5'::half; +SELECT '65504'::half; +SELECT '65505'::half; +SELECT '-65504'::half; +SELECT '-65505'::half; + +SELECT ''::half; +SELECT ' '::half; +SELECT '-'::half; +SELECT ' 1.5'::half; +SELECT '1.5 '::half; +SELECT '1.5a'::half; + +SELECT '{1,2,3}'::half[]; + +SELECT '65505'::integer::half; +SELECT 'NaN'::real::half; +SELECT 'Infinity'::real::half; + +SELECT l2_distance('{0,0}'::half[], '{3,4}'::half[]); +SELECT l2_distance('{0,0}'::half[], '{0,1}'::half[]); +SELECT l2_distance('{1,2}'::half[], '{3}'::half[]); +SELECT '{0,0}'::half[] <-> '{3,4}'::half[]; + +SELECT inner_product('{1,2}'::half[], '{3,4}'::half[]); +SELECT inner_product('{1,2}'::half[], '{3}'::half[]); +SELECT inner_product('{65504}'::half[], '{65504}'::half[]); +SELECT '{1,2}'::half[] <#> '{3,4}'::half[]; + +SELECT cosine_distance('{1,2}'::half[], '{2,4}'::half[]); +SELECT cosine_distance('{1,2}'::half[], '{0,0}'::half[]); +SELECT cosine_distance('{1,1}'::half[], '{1,1}'::half[]); +SELECT cosine_distance('{1,0}'::half[], '{0,2}'::half[]); +SELECT cosine_distance('{1,1}'::half[], '{-1,-1}'::half[]); +SELECT cosine_distance('{1,2}'::half[], '{3}'::half[]); +SELECT cosine_distance('{1,1}'::half[], '{1.1,1.1}'::half[]); +SELECT cosine_distance('{1,1}'::half[], '{-1.1,-1.1}'::half[]); +SELECT '{1,2}'::half[] <=> '{2,4}'::half[]; + +SELECT l1_distance('{0,0}'::half[], '{3,4}'); +SELECT l1_distance('{0,0}'::half[], '{0,1}'); +SELECT l1_distance('{1,2}'::half[], '{3}');