diff --git a/CHANGELOG.md b/CHANGELOG.md index 90378df..30bc3c2 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -1,6 +1,10 @@ +## 0.1.8 (unreleased) + +- Added cast for `vector` to `real[]` + ## 0.1.7 (2021-06-13) -- Added cast for `numeric[]` +- Added cast for `numeric[]` to `vector` ## 0.1.6 (2021-06-09) diff --git a/sql/vector--0.1.7--0.1.8.sql b/sql/vector--0.1.7--0.1.8.sql new file mode 100644 index 0000000..5a387a7 --- /dev/null +++ b/sql/vector--0.1.7--0.1.8.sql @@ -0,0 +1,8 @@ +-- complain if script is sourced in psql, rather than via CREATE EXTENSION +\echo Use "ALTER EXTENSION vector UPDATE TO '0.1.8'" to load this file. \quit + +CREATE FUNCTION vector_to_float4(vector, integer, boolean) RETURNS real[] + AS 'MODULE_PATHNAME' LANGUAGE C IMMUTABLE STRICT PARALLEL SAFE; + +CREATE CAST (vector AS real[]) + WITH FUNCTION vector_to_float4(vector, integer, boolean) AS IMPLICIT; diff --git a/sql/vector.sql b/sql/vector.sql index 65d38ec..452154f 100644 --- a/sql/vector.sql +++ b/sql/vector.sql @@ -100,6 +100,9 @@ CREATE FUNCTION array_to_vector(double precision[], integer, boolean) RETURNS ve CREATE FUNCTION array_to_vector(numeric[], integer, boolean) RETURNS vector AS 'MODULE_PATHNAME' LANGUAGE C IMMUTABLE STRICT PARALLEL SAFE; +CREATE FUNCTION vector_to_float4(vector, integer, boolean) RETURNS real[] + AS 'MODULE_PATHNAME' LANGUAGE C IMMUTABLE STRICT PARALLEL SAFE; + -- casts CREATE CAST (vector AS vector) @@ -117,6 +120,9 @@ CREATE CAST (double precision[] AS vector) CREATE CAST (numeric[] AS vector) WITH FUNCTION array_to_vector(numeric[], integer, boolean) AS IMPLICIT; +CREATE CAST (vector AS real[]) + WITH FUNCTION vector_to_float4(vector, integer, boolean) AS IMPLICIT; + -- operators CREATE OPERATOR <-> ( diff --git a/src/vector.c b/src/vector.c index 31823b5..2710d21 100644 --- a/src/vector.c +++ b/src/vector.c @@ -359,6 +359,29 @@ array_to_vector(PG_FUNCTION_ARGS) PG_RETURN_POINTER(result); } +/* + * Convert vector to float4[] + */ +PG_FUNCTION_INFO_V1(vector_to_float4); +Datum +vector_to_float4(PG_FUNCTION_ARGS) +{ + Vector *vec = PG_GETARG_VECTOR_P(0); + Datum *d; + ArrayType *result; + int i; + + d = (Datum *) palloc(sizeof(Datum) * vec->dim); + + for (i = 0; i < vec->dim; i++) + d[i] = Float4GetDatum(vec->x[i]); + + /* Use TYPALIGN_INT for float4 */ + result = construct_array(d, vec->dim, FLOAT4OID, sizeof(float4), true, TYPALIGN_INT); + + PG_RETURN_POINTER(result); +} + /* * Get the L2 distance between vectors */ diff --git a/test/expected/cast.out b/test/expected/cast.out index 00a1cc0..213f5c8 100644 --- a/test/expected/cast.out +++ b/test/expected/cast.out @@ -34,5 +34,11 @@ SELECT '{-Infinity}'::real[]::vector; ERROR: infinite value not allowed in vector SELECT '{}'::real[]::vector; ERROR: vector must have at least 1 dimension +SELECT '[1,2,3]'::vector::real[]; + float4 +--------- + {1,2,3} +(1 row) + SELECT array_agg(n)::vector FROM generate_series(1, 1025) n; ERROR: vector cannot have more than 1024 dimensions diff --git a/test/sql/cast.sql b/test/sql/cast.sql index 2b0f41e..caf0507 100644 --- a/test/sql/cast.sql +++ b/test/sql/cast.sql @@ -10,4 +10,5 @@ SELECT '{NaN}'::real[]::vector; SELECT '{Infinity}'::real[]::vector; SELECT '{-Infinity}'::real[]::vector; SELECT '{}'::real[]::vector; +SELECT '[1,2,3]'::vector::real[]; SELECT array_agg(n)::vector FROM generate_series(1, 1025) n;