diff --git a/src/halfvec.c b/src/halfvec.c index 6195fbe..9e937b0 100644 --- a/src/halfvec.c +++ b/src/halfvec.c @@ -741,6 +741,27 @@ halfvec_to_vector(PG_FUNCTION_ARGS) PG_RETURN_POINTER(result); } +/* + * Convert vector to half vec + */ +PGDLLEXPORT PG_FUNCTION_INFO_V1(vector_to_halfvec); +Datum +vector_to_halfvec(PG_FUNCTION_ARGS) +{ + Vector *vec = PG_GETARG_VECTOR_P(0); + + /* TODO Check halfvec dims in InitHalfVector */ + HalfVector *result = InitHalfVector(vec->dim); + + for (int i = 0; i < vec->dim; i++) + { + result->x[i] = Float4ToHalfUnchecked(vec->x[i]); + CheckElement(result->x[i]); + } + + PG_RETURN_POINTER(result); +} + /* * Get the L2 distance between half vectors */ diff --git a/src/vector.c b/src/vector.c index d082262..2d0e96a 100644 --- a/src/vector.c +++ b/src/vector.c @@ -533,24 +533,6 @@ vector_to_float4(PG_FUNCTION_ARGS) PG_RETURN_POINTER(result); } -/* - * Convert vector to half vec - */ -PGDLLEXPORT PG_FUNCTION_INFO_V1(vector_to_halfvec); -Datum -vector_to_halfvec(PG_FUNCTION_ARGS) -{ - Vector *vec = PG_GETARG_VECTOR_P(0); - - /* TODO Check halfvec dims in InitHalfVector */ - HalfVector *result = InitHalfVector(vec->dim); - - for (int i = 0; i < vec->dim; i++) - result->x[i] = Float4ToHalf(vec->x[i]); - - PG_RETURN_POINTER(result); -} - /* * Get the L2 distance between vectors */ diff --git a/test/expected/cast.out b/test/expected/cast.out index d438052..f2e1ce7 100644 --- a/test/expected/cast.out +++ b/test/expected/cast.out @@ -58,6 +58,12 @@ SELECT '[1,2,3]'::halfvec::vector; [1,2,3] (1 row) +SELECT '[1e-8]'::vector::halfvec; + halfvec +--------- + [0] +(1 row) + SELECT array_agg(n)::vector FROM generate_series(1, 16001) n; ERROR: vector cannot have more than 16000 dimensions SELECT array_to_vector(array_agg(n), 16001, false) FROM generate_series(1, 16001) n; diff --git a/test/sql/cast.sql b/test/sql/cast.sql index fb574be..b86f4be 100644 --- a/test/sql/cast.sql +++ b/test/sql/cast.sql @@ -12,6 +12,7 @@ SELECT '{{1}}'::real[]::vector; SELECT '[1,2,3]'::vector::real[]; SELECT '[1,2,3]'::vector::halfvec; SELECT '[1,2,3]'::halfvec::vector; +SELECT '[1e-8]'::vector::halfvec; SELECT array_agg(n)::vector FROM generate_series(1, 16001) n; SELECT array_to_vector(array_agg(n), 16001, false) FROM generate_series(1, 16001) n;