From d09aa9f873789f73a8dc1f046a09643a5e5b317a Mon Sep 17 00:00:00 2001 From: Andrew Kane Date: Mon, 4 Dec 2023 12:49:28 -0800 Subject: [PATCH] Added casts between half to double precision [skip ci] --- sql/vector--0.5.1--0.6.0.sql | 12 ++++++++++++ sql/vector.sql | 12 ++++++++++++ src/half.c | 36 ++++++++++++++++++++++++++++++++++++ test/expected/half.out | 12 ++++++++++++ test/sql/half.sql | 3 +++ 5 files changed, 75 insertions(+) diff --git a/sql/vector--0.5.1--0.6.0.sql b/sql/vector--0.5.1--0.6.0.sql index 3b9115e..2865352 100644 --- a/sql/vector--0.5.1--0.6.0.sql +++ b/sql/vector--0.5.1--0.6.0.sql @@ -49,6 +49,12 @@ CREATE FUNCTION float4_to_half(real, integer, boolean) RETURNS half CREATE FUNCTION half_to_float4(half, integer, boolean) RETURNS real AS 'MODULE_PATHNAME' LANGUAGE C IMMUTABLE STRICT PARALLEL SAFE; +CREATE FUNCTION float8_to_half(float8, integer, boolean) RETURNS half + AS 'MODULE_PATHNAME' LANGUAGE C IMMUTABLE STRICT PARALLEL SAFE; + +CREATE FUNCTION half_to_float8(half, integer, boolean) RETURNS float8 + AS 'MODULE_PATHNAME' LANGUAGE C IMMUTABLE STRICT PARALLEL SAFE; + CREATE FUNCTION integer_to_half(integer, integer, boolean) RETURNS half AS 'MODULE_PATHNAME' LANGUAGE C IMMUTABLE STRICT PARALLEL SAFE; @@ -64,6 +70,12 @@ CREATE CAST (real AS half) CREATE CAST (half AS real) WITH FUNCTION half_to_float4(half, integer, boolean) AS IMPLICIT; +CREATE CAST (float8 AS half) + WITH FUNCTION float8_to_half(float8, integer, boolean) AS IMPLICIT; + +CREATE CAST (half AS float8) + WITH FUNCTION half_to_float8(half, integer, boolean) AS IMPLICIT; + CREATE CAST (integer AS half) WITH FUNCTION integer_to_half(integer, integer, boolean) AS IMPLICIT; diff --git a/sql/vector.sql b/sql/vector.sql index dfa1c7c..ad1be73 100644 --- a/sql/vector.sql +++ b/sql/vector.sql @@ -347,6 +347,12 @@ CREATE FUNCTION float4_to_half(real, integer, boolean) RETURNS half CREATE FUNCTION half_to_float4(half, integer, boolean) RETURNS real AS 'MODULE_PATHNAME' LANGUAGE C IMMUTABLE STRICT PARALLEL SAFE; +CREATE FUNCTION float8_to_half(float8, integer, boolean) RETURNS half + AS 'MODULE_PATHNAME' LANGUAGE C IMMUTABLE STRICT PARALLEL SAFE; + +CREATE FUNCTION half_to_float8(half, integer, boolean) RETURNS float8 + AS 'MODULE_PATHNAME' LANGUAGE C IMMUTABLE STRICT PARALLEL SAFE; + CREATE FUNCTION integer_to_half(integer, integer, boolean) RETURNS half AS 'MODULE_PATHNAME' LANGUAGE C IMMUTABLE STRICT PARALLEL SAFE; @@ -364,6 +370,12 @@ CREATE CAST (real AS half) CREATE CAST (half AS real) WITH FUNCTION half_to_float4(half, integer, boolean) AS IMPLICIT; +CREATE CAST (float8 AS half) + WITH FUNCTION float8_to_half(float8, integer, boolean) AS IMPLICIT; + +CREATE CAST (half AS float8) + WITH FUNCTION half_to_float8(half, integer, boolean) AS IMPLICIT; + CREATE CAST (integer AS half) WITH FUNCTION integer_to_half(integer, integer, boolean) AS IMPLICIT; diff --git a/src/half.c b/src/half.c index 5199eff..bcbde77 100644 --- a/src/half.c +++ b/src/half.c @@ -298,6 +298,16 @@ Float4ToHalf(float num) return result; } +/* + * Convert a float8 to a half + */ +static half +Float8ToHalf(double num) +{ + /* TODO Convert directly for greater accuracy */ + return Float4ToHalf((float) num); +} + /* * Convert textual representation to internal representation */ @@ -466,6 +476,32 @@ half_to_float4(PG_FUNCTION_ARGS) PG_RETURN_FLOAT4(f); } +/* + * Convert float8 to half + */ +PGDLLEXPORT PG_FUNCTION_INFO_V1(float8_to_half); +Datum +float8_to_half(PG_FUNCTION_ARGS) +{ + float8 d = PG_GETARG_FLOAT8(0); + half h = Float8ToHalf(d); + + PG_RETURN_HALF(h); +} + +/* + * Convert half to float8 + */ +PGDLLEXPORT PG_FUNCTION_INFO_V1(half_to_float8); +Datum +half_to_float8(PG_FUNCTION_ARGS) +{ + half h = PG_GETARG_HALF(0); + float f = HalfToFloat4(h); + + PG_RETURN_FLOAT8((double) f); +} + /* * Get the L2 distance between half arrays */ diff --git a/test/expected/half.out b/test/expected/half.out index 95f4439..7188963 100644 --- a/test/expected/half.out +++ b/test/expected/half.out @@ -94,6 +94,18 @@ SELECT '1.5'::real::half; 1.5 (1 row) +SELECT '1.5'::half::double precision; + float8 +-------- + 1.5 +(1 row) + +SELECT '1.5'::double precision::half; + half +------ + 1.5 +(1 row) + SELECT '1.5'::half::numeric; numeric --------- diff --git a/test/sql/half.sql b/test/sql/half.sql index 7b01e1c..6588322 100644 --- a/test/sql/half.sql +++ b/test/sql/half.sql @@ -21,6 +21,9 @@ SELECT 'Infinity'::real::half; SELECT '1.5'::half::real; SELECT '1.5'::real::half; +SELECT '1.5'::half::double precision; +SELECT '1.5'::double precision::half; + SELECT '1.5'::half::numeric; SELECT '1.5'::numeric::half;