diff --git a/sql/vector--0.7.4--0.8.0.sql b/sql/vector--0.7.4--0.8.0.sql index 93f591f..f140d09 100644 --- a/sql/vector--0.7.4--0.8.0.sql +++ b/sql/vector--0.7.4--0.8.0.sql @@ -81,9 +81,15 @@ CREATE FUNCTION intvec(intvec, integer, boolean) RETURNS intvec CREATE FUNCTION array_to_intvec(integer[], integer, boolean) RETURNS intvec AS 'MODULE_PATHNAME' LANGUAGE C IMMUTABLE STRICT PARALLEL SAFE; +CREATE FUNCTION intvec_to_int(intvec, integer, boolean) RETURNS int[] + AS 'MODULE_PATHNAME' LANGUAGE C IMMUTABLE STRICT PARALLEL SAFE; + CREATE CAST (intvec AS intvec) WITH FUNCTION intvec(intvec, integer, boolean) AS IMPLICIT; +CREATE CAST (intvec AS int[]) + WITH FUNCTION intvec_to_int(intvec, integer, boolean) AS ASSIGNMENT; + CREATE CAST (integer[] AS intvec) WITH FUNCTION array_to_intvec(integer[], integer, boolean) AS ASSIGNMENT; diff --git a/sql/vector.sql b/sql/vector.sql index 2809fd3..dea65aa 100644 --- a/sql/vector.sql +++ b/sql/vector.sql @@ -735,11 +735,17 @@ CREATE FUNCTION intvec(intvec, integer, boolean) RETURNS intvec CREATE FUNCTION array_to_intvec(integer[], integer, boolean) RETURNS intvec AS 'MODULE_PATHNAME' LANGUAGE C IMMUTABLE STRICT PARALLEL SAFE; +CREATE FUNCTION intvec_to_int(intvec, integer, boolean) RETURNS int[] + AS 'MODULE_PATHNAME' LANGUAGE C IMMUTABLE STRICT PARALLEL SAFE; + -- intvec casts CREATE CAST (intvec AS intvec) WITH FUNCTION intvec(intvec, integer, boolean) AS IMPLICIT; +CREATE CAST (intvec AS int[]) + WITH FUNCTION intvec_to_int(intvec, integer, boolean) AS ASSIGNMENT; + CREATE CAST (integer[] AS intvec) WITH FUNCTION array_to_intvec(integer[], integer, boolean) AS ASSIGNMENT; diff --git a/src/intvec.c b/src/intvec.c index 03716d1..30aa457 100644 --- a/src/intvec.c +++ b/src/intvec.c @@ -415,6 +415,29 @@ array_to_intvec(PG_FUNCTION_ARGS) PG_RETURN_POINTER(result); } +/* + * Convert int vector to int[] + */ +FUNCTION_PREFIX PG_FUNCTION_INFO_V1(intvec_to_int); +Datum +intvec_to_int(PG_FUNCTION_ARGS) +{ + IntVector *vec = PG_GETARG_INTVEC_P(0); + Datum *datums; + ArrayType *result; + + datums = (Datum *) palloc(sizeof(Datum) * vec->dim); + + for (int i = 0; i < vec->dim; i++) + datums[i] = Int32GetDatum((int) vec->x[i]); + + result = construct_array(datums, vec->dim, INT4OID, sizeof(int32), true, TYPALIGN_INT); + + pfree(datums); + + PG_RETURN_POINTER(result); +} + /* * Get the L2 distance between int vectors */ @@ -587,7 +610,7 @@ FUNCTION_PREFIX PG_FUNCTION_INFO_V1(intvec_vector_dims); Datum intvec_vector_dims(PG_FUNCTION_ARGS) { - IntVector *a = PG_GETARG_INTVEC_P(0); + IntVector *a = PG_GETARG_INTVEC_P(0); PG_RETURN_INT32(a->dim); } diff --git a/test/expected/cast.out b/test/expected/cast.out index 621c0da..b641f64 100644 --- a/test/expected/cast.out +++ b/test/expected/cast.out @@ -140,6 +140,12 @@ SELECT '{1e-8,-1e-8}'::real[]::halfvec; [0,-0] (1 row) +SELECT '[1,2,3]'::intvec::int[]; + int4 +--------- + {1,2,3} +(1 row) + SELECT '{1,2,3}'::int[]::intvec; intvec --------- diff --git a/test/sql/cast.sql b/test/sql/cast.sql index 34f94fe..8ce92fe 100644 --- a/test/sql/cast.sql +++ b/test/sql/cast.sql @@ -38,6 +38,8 @@ SELECT '{1,2,3}'::real[]::halfvec(2); SELECT '{65520,-65520}'::real[]::halfvec; SELECT '{1e-8,-1e-8}'::real[]::halfvec; +SELECT '[1,2,3]'::intvec::int[]; + SELECT '{1,2,3}'::int[]::intvec; SELECT '{1,2,3}'::int[]::intvec(3); SELECT '{1,2,3}'::int[]::intvec(2);