From 154e4334fb3005cc5e4d34a94d5ccfe1a6ee527f Mon Sep 17 00:00:00 2001 From: Andrew Kane Date: Fri, 11 Jun 2021 03:32:47 -0700 Subject: [PATCH] Added cast for numeric[] --- CHANGELOG.md | 4 ++++ sql/vector--0.1.6--0.1.7.sql | 8 ++++++++ sql/vector.sql | 6 ++++++ src/vector.c | 3 +++ test/expected/cast.out | 6 ++++++ test/sql/cast.sql | 1 + 6 files changed, 28 insertions(+) create mode 100644 sql/vector--0.1.6--0.1.7.sql diff --git a/CHANGELOG.md b/CHANGELOG.md index f6fccb4..45ece12 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -1,3 +1,7 @@ +## 0.1.7 (unreleased) + +- Added cast for `numeric[]` + ## 0.1.6 (2021-06-09) - Fixed segmentation fault with `COUNT` diff --git a/sql/vector--0.1.6--0.1.7.sql b/sql/vector--0.1.6--0.1.7.sql new file mode 100644 index 0000000..8a607f7 --- /dev/null +++ b/sql/vector--0.1.6--0.1.7.sql @@ -0,0 +1,8 @@ +-- complain if script is sourced in psql, rather than via CREATE EXTENSION +\echo Use "ALTER EXTENSION vector UPDATE TO '0.1.6'" to load this file. \quit + +CREATE FUNCTION array_to_vector(numeric[], integer, boolean) RETURNS vector + AS 'MODULE_PATHNAME' LANGUAGE C IMMUTABLE STRICT PARALLEL SAFE; + +CREATE CAST (numeric[] AS vector) + WITH FUNCTION array_to_vector(numeric[], integer, boolean) AS IMPLICIT; diff --git a/sql/vector.sql b/sql/vector.sql index 73892c0..65d38ec 100644 --- a/sql/vector.sql +++ b/sql/vector.sql @@ -97,6 +97,9 @@ CREATE FUNCTION array_to_vector(real[], integer, boolean) RETURNS vector CREATE FUNCTION array_to_vector(double precision[], integer, boolean) RETURNS vector AS 'MODULE_PATHNAME' LANGUAGE C IMMUTABLE STRICT PARALLEL SAFE; +CREATE FUNCTION array_to_vector(numeric[], integer, boolean) RETURNS vector + AS 'MODULE_PATHNAME' LANGUAGE C IMMUTABLE STRICT PARALLEL SAFE; + -- casts CREATE CAST (vector AS vector) @@ -111,6 +114,9 @@ CREATE CAST (real[] AS vector) CREATE CAST (double precision[] AS vector) WITH FUNCTION array_to_vector(double precision[], integer, boolean) AS IMPLICIT; +CREATE CAST (numeric[] AS vector) + WITH FUNCTION array_to_vector(numeric[], integer, boolean) AS IMPLICIT; + -- operators CREATE OPERATOR <-> ( diff --git a/src/vector.c b/src/vector.c index c0669ee..31823b5 100644 --- a/src/vector.c +++ b/src/vector.c @@ -10,6 +10,7 @@ #include "utils/array.h" #include "utils/builtins.h" #include "utils/lsyscache.h" +#include "utils/numeric.h" #if PG_VERSION_NUM >= 120000 #include "utils/float.h" @@ -345,6 +346,8 @@ array_to_vector(PG_FUNCTION_ARGS) result->x[i] = DatumGetFloat8(elemsp[i]); else if (ARR_ELEMTYPE(array) == FLOAT4OID) result->x[i] = DatumGetFloat4(elemsp[i]); + else if (ARR_ELEMTYPE(array) == NUMERICOID) + result->x[i] = DatumGetFloat4(DirectFunctionCall1(numeric_float4, NumericGetDatum(elemsp[i]))); else ereport(ERROR, (errcode(ERRCODE_DATA_EXCEPTION), diff --git a/test/expected/cast.out b/test/expected/cast.out index 0086961..00a1cc0 100644 --- a/test/expected/cast.out +++ b/test/expected/cast.out @@ -6,6 +6,12 @@ SELECT ARRAY[1,2,3]::vector; [1,2,3] (1 row) +SELECT ARRAY[1.0,2.0,3.0]::vector; + array +--------- + [1,2,3] +(1 row) + SELECT ARRAY[1,2,3]::float4[]::vector; array --------- diff --git a/test/sql/cast.sql b/test/sql/cast.sql index 2997224..2b0f41e 100644 --- a/test/sql/cast.sql +++ b/test/sql/cast.sql @@ -2,6 +2,7 @@ SET client_min_messages = warning; CREATE EXTENSION IF NOT EXISTS vector; SELECT ARRAY[1,2,3]::vector; +SELECT ARRAY[1.0,2.0,3.0]::vector; SELECT ARRAY[1,2,3]::float4[]::vector; SELECT ARRAY[1,2,3]::float8[]::vector; SELECT '{NULL}'::real[]::vector;