From 4f1a37963815a23deaadaeb7890e00a4d6a12be5 Mon Sep 17 00:00:00 2001
From: Andrew Kane
Date: Thu, 28 Mar 2024 14:04:59 -0700
Subject: [PATCH] Added casting between vector and halfvec

The casts are declared with three-argument functions, so PostgreSQL hands
the target typmod to the function and skips the separate length coercion.
Both cast functions therefore enforce the expected number of dimensions
themselves, so that e.g. '[1,2,3]'::halfvec::vector(2) raises an error
instead of returning a value that violates the declared typmod.
---
 sql/vector--0.6.2--0.7.0.sql | 12 ++++++++++++
 sql/vector.sql               | 14 ++++++++++++++
 src/halfvec.c                | 31 +++++++++++++++++++++++++++++++
 src/vector.c                 | 31 +++++++++++++++++++++++++++++++
 test/expected/cast.out       | 16 ++++++++++++++++
 test/sql/cast.sql            |  4 ++++
 6 files changed, 108 insertions(+)

diff --git a/sql/vector--0.6.2--0.7.0.sql b/sql/vector--0.6.2--0.7.0.sql
index f63d28b..999bedf 100644
--- a/sql/vector--0.6.2--0.7.0.sql
+++ b/sql/vector--0.6.2--0.7.0.sql
@@ -114,3 +114,15 @@ CREATE OPERATOR CLASS halfvec_cosine_ops
 	OPERATOR 1 <=> (halfvec, halfvec) FOR ORDER BY float_ops,
 	FUNCTION 1 halfvec_negative_inner_product(halfvec, halfvec),
 	FUNCTION 2 halfvec_norm(halfvec);
+
+CREATE FUNCTION halfvec_to_vector(halfvec, integer, boolean) RETURNS vector
+	AS 'MODULE_PATHNAME' LANGUAGE C IMMUTABLE STRICT PARALLEL SAFE;
+
+CREATE FUNCTION vector_to_halfvec(vector, integer, boolean) RETURNS halfvec
+	AS 'MODULE_PATHNAME' LANGUAGE C IMMUTABLE STRICT PARALLEL SAFE;
+
+CREATE CAST (halfvec AS vector)
+	WITH FUNCTION halfvec_to_vector(halfvec, integer, boolean) AS ASSIGNMENT;
+
+CREATE CAST (vector AS halfvec)
+	WITH FUNCTION vector_to_halfvec(vector, integer, boolean) AS ASSIGNMENT;
diff --git a/sql/vector.sql b/sql/vector.sql
index 9b78384..8e6bc8d 100644
--- a/sql/vector.sql
+++ b/sql/vector.sql
@@ -415,3 +415,17 @@ CREATE OPERATOR CLASS halfvec_cosine_ops
 	OPERATOR 1 <=> (halfvec, halfvec) FOR ORDER BY float_ops,
 	FUNCTION 1 halfvec_negative_inner_product(halfvec, halfvec),
 	FUNCTION 2 halfvec_norm(halfvec);
+
+-- extension casts
+
+CREATE FUNCTION halfvec_to_vector(halfvec, integer, boolean) RETURNS vector
+	AS 'MODULE_PATHNAME' LANGUAGE C IMMUTABLE STRICT PARALLEL SAFE;
+
+CREATE FUNCTION vector_to_halfvec(vector, integer, boolean) RETURNS halfvec
+	AS 'MODULE_PATHNAME' LANGUAGE C IMMUTABLE STRICT PARALLEL SAFE;
+
+CREATE CAST (halfvec AS vector)
+	WITH FUNCTION halfvec_to_vector(halfvec, integer, boolean) AS ASSIGNMENT;
+
+CREATE CAST (vector AS halfvec)
+	WITH FUNCTION vector_to_halfvec(vector, integer, boolean) AS ASSIGNMENT;
diff --git a/src/halfvec.c b/src/halfvec.c
index 26cd782..6195fbe 100644
--- a/src/halfvec.c
+++ b/src/halfvec.c
@@ -14,6 +14,7 @@
 #include "utils/float.h"
 #include "utils/lsyscache.h"
 #include "utils/numeric.h"
+#include "vector.h"
 
 #if PG_VERSION_NUM < 130000
 #define TYPALIGN_DOUBLE 'd'
@@ -722,6 +723,36 @@ halfvec_to_float4(PG_FUNCTION_ARGS)
 	PG_RETURN_POINTER(result);
 }
 
+/*
+ * Convert half vector to vector
+ *
+ * The cast is declared with a three-argument function, so PostgreSQL passes
+ * the target typmod and applies no separate length coercion; the expected
+ * dimensions are therefore checked here.
+ */
+PGDLLEXPORT PG_FUNCTION_INFO_V1(halfvec_to_vector);
+Datum
+halfvec_to_vector(PG_FUNCTION_ARGS)
+{
+	HalfVector *vec = PG_GETARG_HALFVEC_P(0);
+	int32		typmod = PG_GETARG_INT32(1);
+	Vector	   *result;
+
+	/* typmod is -1 when the target type carries no dimension constraint */
+	if (typmod != -1 && typmod != vec->dim)
+		ereport(ERROR,
+				(errcode(ERRCODE_DATA_EXCEPTION),
+				 errmsg("expected %d dimensions, not %d", typmod, vec->dim)));
+
+	/* TODO Check vector dims in InitVector */
+	result = InitVector(vec->dim);
+
+	for (int i = 0; i < vec->dim; i++)
+		result->x[i] = HalfToFloat4(vec->x[i]);
+
+	PG_RETURN_POINTER(result);
+}
+
 /*
  * Get the L2 distance between half vectors
  */
diff --git a/src/vector.c b/src/vector.c
index 5f3cbbb..d082262 100644
--- a/src/vector.c
+++ b/src/vector.c
@@ -5,6 +5,7 @@
 #include "catalog/pg_type.h"
 #include "common/shortest_dec.h"
 #include "fmgr.h"
+#include "halfvec.h"
 #include "hnsw.h"
 #include "ivfflat.h"
 #include "lib/stringinfo.h"
@@ -532,6 +533,36 @@ vector_to_float4(PG_FUNCTION_ARGS)
 	PG_RETURN_POINTER(result);
 }
 
+/*
+ * Convert vector to half vector
+ *
+ * The cast is declared with a three-argument function, so PostgreSQL passes
+ * the target typmod and applies no separate length coercion; the expected
+ * dimensions are therefore checked here.
+ */
+PGDLLEXPORT PG_FUNCTION_INFO_V1(vector_to_halfvec);
+Datum
+vector_to_halfvec(PG_FUNCTION_ARGS)
+{
+	Vector	   *vec = PG_GETARG_VECTOR_P(0);
+	int32		typmod = PG_GETARG_INT32(1);
+	HalfVector *result;
+
+	/* typmod is -1 when the target type carries no dimension constraint */
+	if (typmod != -1 && typmod != vec->dim)
+		ereport(ERROR,
+				(errcode(ERRCODE_DATA_EXCEPTION),
+				 errmsg("expected %d dimensions, not %d", typmod, vec->dim)));
+
+	/* TODO Check halfvec dims in InitHalfVector */
+	result = InitHalfVector(vec->dim);
+
+	for (int i = 0; i < vec->dim; i++)
+		result->x[i] = Float4ToHalf(vec->x[i]);
+
+	PG_RETURN_POINTER(result);
+}
+
 /*
  * Get the L2 distance between vectors
  */
diff --git a/test/expected/cast.out b/test/expected/cast.out
index 4824261..d438052 100644
--- a/test/expected/cast.out
+++ b/test/expected/cast.out
@@ -46,6 +46,22 @@ SELECT '[1,2,3]'::vector::real[];
  {1,2,3}
 (1 row)
 
+SELECT '[1,2,3]'::vector::halfvec;
+ halfvec 
+---------
+ [1,2,3]
+(1 row)
+
+SELECT '[1,2,3]'::halfvec::vector;
+ vector  
+---------
+ [1,2,3]
+(1 row)
+
+SELECT '[1,2,3]'::vector::halfvec(2);
+ERROR:  expected 2 dimensions, not 3
+SELECT '[1,2,3]'::halfvec::vector(2);
+ERROR:  expected 2 dimensions, not 3
 SELECT array_agg(n)::vector FROM generate_series(1, 16001) n;
 ERROR:  vector cannot have more than 16000 dimensions
 SELECT array_to_vector(array_agg(n), 16001, false) FROM generate_series(1, 16001) n;
diff --git a/test/sql/cast.sql b/test/sql/cast.sql
index c73ab07..fb574be 100644
--- a/test/sql/cast.sql
+++ b/test/sql/cast.sql
@@ -10,6 +10,10 @@ SELECT '{-Infinity}'::real[]::vector;
 SELECT '{}'::real[]::vector;
 SELECT '{{1}}'::real[]::vector;
 SELECT '[1,2,3]'::vector::real[];
+SELECT '[1,2,3]'::vector::halfvec;
+SELECT '[1,2,3]'::halfvec::vector;
+SELECT '[1,2,3]'::vector::halfvec(2);
+SELECT '[1,2,3]'::halfvec::vector(2);
 SELECT array_agg(n)::vector FROM generate_series(1, 16001) n;
 SELECT array_to_vector(array_agg(n), 16001, false) FROM generate_series(1, 16001) n;