diff --git a/sql/vector.sql b/sql/vector.sql index 30d74c7..15f484c 100644 --- a/sql/vector.sql +++ b/sql/vector.sql @@ -748,6 +748,58 @@ CREATE FUNCTION minivec_l2_squared_distance(minivec, minivec) RETURNS float8 CREATE FUNCTION minivec_negative_inner_product(minivec, minivec) RETURNS float8 AS 'MODULE_PATHNAME' LANGUAGE C IMMUTABLE STRICT PARALLEL SAFE; +-- minivec cast functions + +CREATE FUNCTION minivec(minivec, integer, boolean) RETURNS minivec + AS 'MODULE_PATHNAME' LANGUAGE C IMMUTABLE STRICT PARALLEL SAFE; + +CREATE FUNCTION minivec_to_vector(minivec, integer, boolean) RETURNS vector + AS 'MODULE_PATHNAME' LANGUAGE C IMMUTABLE STRICT PARALLEL SAFE; + +CREATE FUNCTION vector_to_minivec(vector, integer, boolean) RETURNS minivec + AS 'MODULE_PATHNAME' LANGUAGE C IMMUTABLE STRICT PARALLEL SAFE; + +CREATE FUNCTION array_to_minivec(integer[], integer, boolean) RETURNS minivec + AS 'MODULE_PATHNAME' LANGUAGE C IMMUTABLE STRICT PARALLEL SAFE; + +CREATE FUNCTION array_to_minivec(real[], integer, boolean) RETURNS minivec + AS 'MODULE_PATHNAME' LANGUAGE C IMMUTABLE STRICT PARALLEL SAFE; + +CREATE FUNCTION array_to_minivec(double precision[], integer, boolean) RETURNS minivec + AS 'MODULE_PATHNAME' LANGUAGE C IMMUTABLE STRICT PARALLEL SAFE; + +CREATE FUNCTION array_to_minivec(numeric[], integer, boolean) RETURNS minivec + AS 'MODULE_PATHNAME' LANGUAGE C IMMUTABLE STRICT PARALLEL SAFE; + +CREATE FUNCTION minivec_to_float4(minivec, integer, boolean) RETURNS real[] + AS 'MODULE_PATHNAME' LANGUAGE C IMMUTABLE STRICT PARALLEL SAFE; + +-- minivec casts + +CREATE CAST (minivec AS minivec) + WITH FUNCTION minivec(minivec, integer, boolean) AS IMPLICIT; + +CREATE CAST (minivec AS vector) + WITH FUNCTION minivec_to_vector(minivec, integer, boolean) AS ASSIGNMENT; + +CREATE CAST (vector AS minivec) + WITH FUNCTION vector_to_minivec(vector, integer, boolean) AS IMPLICIT; + +CREATE CAST (minivec AS real[]) + WITH FUNCTION minivec_to_float4(minivec, integer, boolean) AS ASSIGNMENT; + +CREATE CAST (integer[] AS minivec) + WITH FUNCTION array_to_minivec(integer[], integer, boolean) AS ASSIGNMENT; + +CREATE CAST (real[] AS minivec) + WITH FUNCTION array_to_minivec(real[], integer, boolean) AS ASSIGNMENT; + +CREATE CAST (double precision[] AS minivec) + WITH FUNCTION array_to_minivec(double precision[], integer, boolean) AS ASSIGNMENT; + +CREATE CAST (numeric[] AS minivec) + WITH FUNCTION array_to_minivec(numeric[], integer, boolean) AS ASSIGNMENT; + -- minivec operators CREATE OPERATOR <-> ( diff --git a/src/minivec.c b/src/minivec.c index de66ac8..f73ab3c 100644 --- a/src/minivec.c +++ b/src/minivec.c @@ -348,6 +348,142 @@ minivec_send(PG_FUNCTION_ARGS) PG_RETURN_BYTEA_P(pq_endtypsend(&buf)); } +/* + * Convert fp8 vector to fp8 vector + * This is needed to check the type modifier + */ +FUNCTION_PREFIX PG_FUNCTION_INFO_V1(minivec); +Datum +minivec(PG_FUNCTION_ARGS) +{ + MiniVector *vec = PG_GETARG_MINIVEC_P(0); + int32 typmod = PG_GETARG_INT32(1); + + CheckExpectedDim(typmod, vec->dim); + + PG_RETURN_POINTER(vec); +} + +/* + * Convert array to fp8 vector + */ +FUNCTION_PREFIX PG_FUNCTION_INFO_V1(array_to_minivec); +Datum +array_to_minivec(PG_FUNCTION_ARGS) +{ + ArrayType *array = PG_GETARG_ARRAYTYPE_P(0); + int32 typmod = PG_GETARG_INT32(1); + MiniVector *result; + int16 typlen; + bool typbyval; + char typalign; + Datum *elemsp; + int nelemsp; + + if (ARR_NDIM(array) > 1) + ereport(ERROR, + (errcode(ERRCODE_DATA_EXCEPTION), + errmsg("array must be 1-D"))); + + if (ARR_HASNULL(array) && array_contains_nulls(array)) + ereport(ERROR, + (errcode(ERRCODE_NULL_VALUE_NOT_ALLOWED), + errmsg("array must not contain nulls"))); + + get_typlenbyvalalign(ARR_ELEMTYPE(array), &typlen, &typbyval, &typalign); + deconstruct_array(array, ARR_ELEMTYPE(array), typlen, typbyval, typalign, &elemsp, NULL, &nelemsp); + + CheckDim(nelemsp); + CheckExpectedDim(typmod, nelemsp); + + result = InitMiniVector(nelemsp); + + if (ARR_ELEMTYPE(array) == INT4OID) + { + for (int i = 0; i < nelemsp; i++) + result->x[i] = Float4ToFp8(DatumGetInt32(elemsp[i])); + } + else if (ARR_ELEMTYPE(array) == FLOAT8OID) + { + for (int i = 0; i < nelemsp; i++) + result->x[i] = Float4ToFp8(DatumGetFloat8(elemsp[i])); + } + else if (ARR_ELEMTYPE(array) == FLOAT4OID) + { + for (int i = 0; i < nelemsp; i++) + result->x[i] = Float4ToFp8(DatumGetFloat4(elemsp[i])); + } + else if (ARR_ELEMTYPE(array) == NUMERICOID) + { + for (int i = 0; i < nelemsp; i++) + result->x[i] = Float4ToFp8(DatumGetFloat4(DirectFunctionCall1(numeric_float4, elemsp[i]))); + } + else + { + ereport(ERROR, + (errcode(ERRCODE_DATA_EXCEPTION), + errmsg("unsupported array type"))); + } + + /* + * Free allocation from deconstruct_array. Do not free individual elements + * when pass-by-reference since they point to original array. + */ + pfree(elemsp); + + /* Check elements */ + for (int i = 0; i < result->dim; i++) + CheckElement(result->x[i]); + + PG_RETURN_POINTER(result); +} + +/* + * Convert fp8 vector to float4[] + */ +FUNCTION_PREFIX PG_FUNCTION_INFO_V1(minivec_to_float4); +Datum +minivec_to_float4(PG_FUNCTION_ARGS) +{ + MiniVector *vec = PG_GETARG_MINIVEC_P(0); + Datum *datums; + ArrayType *result; + + datums = (Datum *) palloc(sizeof(Datum) * vec->dim); + + for (int i = 0; i < vec->dim; i++) + datums[i] = Float4GetDatum(Fp8ToFloat4(vec->x[i])); + + /* Use TYPALIGN_INT for float4 */ + result = construct_array(datums, vec->dim, FLOAT4OID, sizeof(float4), true, TYPALIGN_INT); + + pfree(datums); + + PG_RETURN_POINTER(result); +} + +/* + * Convert vector to fp8 vector + */ +FUNCTION_PREFIX PG_FUNCTION_INFO_V1(vector_to_minivec); +Datum +vector_to_minivec(PG_FUNCTION_ARGS) +{ + Vector *vec = PG_GETARG_VECTOR_P(0); + int32 typmod = PG_GETARG_INT32(1); + MiniVector *result; + + CheckDim(vec->dim); + CheckExpectedDim(typmod, vec->dim); + + result = InitMiniVector(vec->dim); + + for (int i = 0; i < vec->dim; i++) + result->x[i] = Float4ToFp8(vec->x[i]); + + PG_RETURN_POINTER(result); +} + static float MinivecL2SquaredDistance(int dim, fp8 * ax, fp8 * bx) { diff --git a/src/minivec.h b/src/minivec.h index 5e1e2c2..ddca3fc 100644 --- a/src/minivec.h +++ b/src/minivec.h @@ -40,14 +40,13 @@ Fp8IsZero(fp8 num) return num == 0; } -float lookup[128] = {0, 0.00195312, 0.00390625, 0.00585938, 0.0078125, 0.00976562, 0.0117188, 0.0136719, 0.015625, 0.0175781, 0.0195312, 0.0214844, 0.0234375, 0.0253906, 0.0273438, 0.0292969, 0.03125, 0.0351562, 0.0390625, 0.0429688, 0.046875, 0.0507812, 0.0546875, 0.0585938, 0.0625, 0.0703125, 0.078125, 0.0859375, 0.09375, 0.101562, 0.109375, 0.117188, 0.125, 0.140625, 0.15625, 0.171875, 0.1875, 0.203125, 0.21875, 0.234375, 0.25, 0.28125, 0.3125, 0.34375, 0.375, 0.40625, 0.4375, 0.46875, 0.5, 0.5625, 0.625, 0.6875, 0.75, 0.8125, 0.875, 0.9375, 1, 1.125, 1.25, 1.375, 1.5, 1.625, 1.75, 1.875, 2, 2.25, 2.5, 2.75, 3, 3.25, 3.5, 3.75, 4, 4.5, 5, 5.5, 6, 6.5, 7, 7.5, 8, 9, 10, 11, 12, 13, 14, 15, 16, 18, 20, 22, 24, 26, 28, 30, 32, 36, 40, 44, 48, 52, 56, 60, 64, 72, 80, 88, 96, 104, 112, 120, 128, 144, 160, 176, 192, 208, 224, 240, 256, 288, 320, 352, 384, 416, 448, NAN}; - /* * Convert a fp8 to a float4 */ static inline float Fp8ToFloat4(fp8 num) { + float lookup[128] = {0, 0.00195312, 0.00390625, 0.00585938, 0.0078125, 0.00976562, 0.0117188, 0.0136719, 0.015625, 0.0175781, 0.0195312, 0.0214844, 0.0234375, 0.0253906, 0.0273438, 0.0292969, 0.03125, 0.0351562, 0.0390625, 0.0429688, 0.046875, 0.0507812, 0.0546875, 0.0585938, 0.0625, 0.0703125, 0.078125, 0.0859375, 0.09375, 0.101562, 0.109375, 0.117188, 0.125, 0.140625, 0.15625, 0.171875, 0.1875, 0.203125, 0.21875, 0.234375, 0.25, 0.28125, 0.3125, 0.34375, 0.375, 0.40625, 0.4375, 0.46875, 0.5, 0.5625, 0.625, 0.6875, 0.75, 0.8125, 0.875, 0.9375, 1, 1.125, 1.25, 1.375, 1.5, 1.625, 1.75, 1.875, 2, 2.25, 2.5, 2.75, 3, 3.25, 3.5, 3.75, 4, 4.5, 5, 5.5, 6, 6.5, 7, 7.5, 8, 9, 10, 11, 12, 13, 14, 15, 16, 18, 20, 22, 24, 26, 28, 30, 32, 36, 40, 44, 48, 52, 56, 60, 64, 72, 80, 88, 96, 104, 112, 120, 128, 144, 160, 176, 192, 208, 224, 240, 256, 288, 320, 352, 384, 416, 448, NAN}; float v = lookup[num & 0x7F]; return (num & 0x80) == 0x80 ? -v : v; @@ -132,4 +131,26 @@ Float4ToFp8Unchecked(float num) return result; } +/* + * Convert a float4 to a fp8 + */ +static inline fp8 +Float4ToFp8(float num) +{ + fp8 result = Float4ToFp8Unchecked(num); + + if (unlikely(Fp8IsNan(result)) && !isnan(num)) + { + char *buf = palloc(FLOAT_SHORTEST_DECIMAL_LEN); + + float_to_shortest_decimal_buf(num, buf); + + ereport(ERROR, + (errcode(ERRCODE_NUMERIC_VALUE_OUT_OF_RANGE), + errmsg("\"%s\" is out of range for type minivec", buf))); + } + + return result; +} + #endif diff --git a/src/vector.c b/src/vector.c index a5b2aac..2c421c9 100644 --- a/src/vector.c +++ b/src/vector.c @@ -13,6 +13,7 @@ #include "ivfflat.h" #include "lib/stringinfo.h" #include "libpq/pqformat.h" +#include "minivec.h" #include "port.h" /* for strtof() */ #include "sparsevec.h" #include "utils/array.h" @@ -542,6 +543,28 @@ halfvec_to_vector(PG_FUNCTION_ARGS) PG_RETURN_POINTER(result); } +/* + * Convert fp8 vector to vector + */ +FUNCTION_PREFIX PG_FUNCTION_INFO_V1(minivec_to_vector); +Datum +minivec_to_vector(PG_FUNCTION_ARGS) +{ + MiniVector *vec = PG_GETARG_MINIVEC_P(0); + int32 typmod = PG_GETARG_INT32(1); + Vector *result; + + CheckDim(vec->dim); + CheckExpectedDim(typmod, vec->dim); + + result = InitVector(vec->dim); + + for (int i = 0; i < vec->dim; i++) + result->x[i] = Fp8ToFloat4(vec->x[i]); + + PG_RETURN_POINTER(result); +} + VECTOR_TARGET_CLONES static float VectorL2SquaredDistance(int dim, float *ax, float *bx) { diff --git a/test/expected/cast.out b/test/expected/cast.out index c180fe6..34a57c3 100644 --- a/test/expected/cast.out +++ b/test/expected/cast.out @@ -140,6 +140,64 @@ SELECT '{1e-8,-1e-8}'::real[]::halfvec; [0,-0] (1 row) +SELECT '[1,2,3]'::vector::minivec; + minivec +--------- + [1,2,3] +(1 row) + +SELECT '[1,2,3]'::vector::minivec(3); + minivec +--------- + [1,2,3] +(1 row) + +SELECT '[1,2,3]'::vector::minivec(2); +ERROR: expected 2 dimensions, not 3 +SELECT '[465]'::vector::minivec; +ERROR: "465" is out of range for type minivec +SELECT '[1e-8]'::vector::minivec; + minivec +--------- + [0] +(1 row) + +SELECT '[1,2,3]'::minivec::vector; + vector +--------- + [1,2,3] +(1 row) + +SELECT '[1,2,3]'::minivec::vector(3); + vector +--------- + [1,2,3] +(1 row) + +SELECT '[1,2,3]'::minivec::vector(2); +ERROR: expected 2 dimensions, not 3 +SELECT '{1,2,3}'::real[]::minivec; + minivec +--------- + [1,2,3] +(1 row) + +SELECT '{1,2,3}'::real[]::minivec(3); + minivec +--------- + [1,2,3] +(1 row) + +SELECT '{1,2,3}'::real[]::minivec(2); +ERROR: expected 2 dimensions, not 3 +SELECT '{465,-465}'::real[]::minivec; +ERROR: "465" is out of range for type minivec +SELECT '{1e-8,-1e-8}'::real[]::minivec; + minivec +--------- + [0,-0] +(1 row) + SELECT '[0,1.5,0,3.5,0]'::vector::sparsevec; sparsevec ----------------- diff --git a/test/expected/minivec.out b/test/expected/minivec.out index 426a14e..ab5224c 100644 --- a/test/expected/minivec.out +++ b/test/expected/minivec.out @@ -134,11 +134,7 @@ SELECT '[1,2,3]'::minivec(3); (1 row) SELECT '[1,2,3]'::minivec(2); - minivec ---------- - [1,2,3] -(1 row) - +ERROR: expected 2 dimensions, not 3 SELECT '[1,2,3]'::minivec(3, 2); ERROR: invalid type modifier LINE 1: SELECT '[1,2,3]'::minivec(3, 2); @@ -163,11 +159,7 @@ SELECT unnest('{"[1,2,3]", "[4,5,6]"}'::minivec[]); (2 rows) SELECT '{"[1,2,3]"}'::minivec(2)[]; - minivec -------------- - {"[1,2,3]"} -(1 row) - +ERROR: expected 2 dimensions, not 3 SELECT '[1,2,3]'::minivec + '[4,5,6]'; ?column? ---------- @@ -211,9 +203,7 @@ SELECT '[1,2,3]'::minivec || '[4,5]'; (1 row) SELECT array_fill(0, ARRAY[16000])::minivec || '[1]'; -ERROR: cannot cast type integer[] to minivec -LINE 1: SELECT array_fill(0, ARRAY[16000])::minivec || '[1]'; - ^ +ERROR: minivec cannot have more than 16000 dimensions SELECT '[1,2,3]'::minivec < '[1,2,3]'; ?column? ---------- diff --git a/test/sql/cast.sql b/test/sql/cast.sql index fe83931..5db8436 100644 --- a/test/sql/cast.sql +++ b/test/sql/cast.sql @@ -38,6 +38,22 @@ SELECT '{1,2,3}'::real[]::halfvec(2); SELECT '{65520,-65520}'::real[]::halfvec; SELECT '{1e-8,-1e-8}'::real[]::halfvec; +SELECT '[1,2,3]'::vector::minivec; +SELECT '[1,2,3]'::vector::minivec(3); +SELECT '[1,2,3]'::vector::minivec(2); +SELECT '[465]'::vector::minivec; +SELECT '[1e-8]'::vector::minivec; + +SELECT '[1,2,3]'::minivec::vector; +SELECT '[1,2,3]'::minivec::vector(3); +SELECT '[1,2,3]'::minivec::vector(2); + +SELECT '{1,2,3}'::real[]::minivec; +SELECT '{1,2,3}'::real[]::minivec(3); +SELECT '{1,2,3}'::real[]::minivec(2); +SELECT '{465,-465}'::real[]::minivec; +SELECT '{1e-8,-1e-8}'::real[]::minivec; + SELECT '[0,1.5,0,3.5,0]'::vector::sparsevec; SELECT '[0,1.5,0,3.5,0]'::vector::sparsevec(5); SELECT '[0,1.5,0,3.5,0]'::vector::sparsevec(4);