diff --git a/sql/vector--0.6.2--0.7.0.sql b/sql/vector--0.6.2--0.7.0.sql index 0567775..f9483fc 100644 --- a/sql/vector--0.6.2--0.7.0.sql +++ b/sql/vector--0.6.2--0.7.0.sql @@ -415,6 +415,12 @@ CREATE FUNCTION vector_to_sparsevec(vector, integer, boolean) RETURNS sparsevec CREATE FUNCTION sparsevec_to_vector(sparsevec, integer, boolean) RETURNS vector AS 'MODULE_PATHNAME' LANGUAGE C IMMUTABLE STRICT PARALLEL SAFE; +CREATE FUNCTION halfvec_to_sparsevec(halfvec, integer, boolean) RETURNS sparsevec + AS 'MODULE_PATHNAME' LANGUAGE C IMMUTABLE STRICT PARALLEL SAFE; + +CREATE FUNCTION sparsevec_to_halfvec(sparsevec, integer, boolean) RETURNS halfvec + AS 'MODULE_PATHNAME' LANGUAGE C IMMUTABLE STRICT PARALLEL SAFE; + CREATE CAST (sparsevec AS sparsevec) WITH FUNCTION sparsevec(sparsevec, integer, boolean) AS IMPLICIT; @@ -424,6 +430,12 @@ CREATE CAST (sparsevec AS vector) CREATE CAST (vector AS sparsevec) WITH FUNCTION vector_to_sparsevec(vector, integer, boolean) AS IMPLICIT; +CREATE CAST (sparsevec AS halfvec) + WITH FUNCTION sparsevec_to_halfvec(sparsevec, integer, boolean) AS ASSIGNMENT; + +CREATE CAST (halfvec AS sparsevec) + WITH FUNCTION halfvec_to_sparsevec(halfvec, integer, boolean) AS IMPLICIT; + CREATE OPERATOR <-> ( LEFTARG = sparsevec, RIGHTARG = sparsevec, PROCEDURE = l2_distance, COMMUTATOR = '<->' diff --git a/sql/vector.sql b/sql/vector.sql index d2fe687..6dcdbd7 100644 --- a/sql/vector.sql +++ b/sql/vector.sql @@ -732,6 +732,12 @@ CREATE FUNCTION vector_to_sparsevec(vector, integer, boolean) RETURNS sparsevec CREATE FUNCTION sparsevec_to_vector(sparsevec, integer, boolean) RETURNS vector AS 'MODULE_PATHNAME' LANGUAGE C IMMUTABLE STRICT PARALLEL SAFE; +CREATE FUNCTION halfvec_to_sparsevec(halfvec, integer, boolean) RETURNS sparsevec + AS 'MODULE_PATHNAME' LANGUAGE C IMMUTABLE STRICT PARALLEL SAFE; + +CREATE FUNCTION sparsevec_to_halfvec(sparsevec, integer, boolean) RETURNS halfvec + AS 'MODULE_PATHNAME' LANGUAGE C IMMUTABLE STRICT PARALLEL SAFE; + -- sparsevec casts CREATE CAST (sparsevec AS sparsevec) @@ -743,6 +749,12 @@ CREATE CAST (sparsevec AS vector) CREATE CAST (vector AS sparsevec) WITH FUNCTION vector_to_sparsevec(vector, integer, boolean) AS IMPLICIT; +CREATE CAST (sparsevec AS halfvec) + WITH FUNCTION sparsevec_to_halfvec(sparsevec, integer, boolean) AS ASSIGNMENT; + +CREATE CAST (halfvec AS sparsevec) + WITH FUNCTION halfvec_to_sparsevec(halfvec, integer, boolean) AS IMPLICIT; + -- sparsevec operators CREATE OPERATOR <-> ( diff --git a/src/halfvec.c b/src/halfvec.c index d57ed28..582c137 100644 --- a/src/halfvec.c +++ b/src/halfvec.c @@ -11,6 +11,7 @@ #include "lib/stringinfo.h" #include "libpq/pqformat.h" #include "port.h" /* for strtof() */ +#include "sparsevec.h" #include "utils/array.h" #include "utils/builtins.h" #include "utils/float.h" @@ -1174,3 +1175,26 @@ halfvec_avg(PG_FUNCTION_ARGS) PG_RETURN_POINTER(result); } + +/* + * Convert sparse vector to half vector + */ +PGDLLEXPORT PG_FUNCTION_INFO_V1(sparsevec_to_halfvec); +Datum +sparsevec_to_halfvec(PG_FUNCTION_ARGS) +{ + SparseVector *svec = PG_GETARG_SPARSEVEC_P(0); + int32 typmod = PG_GETARG_INT32(1); + HalfVector *result; + int dim = svec->dim; + float *values = SPARSEVEC_VALUES(svec); + + CheckDim(dim); + CheckExpectedDim(typmod, dim); + + result = InitHalfVector(dim); + for (int i = 0; i < svec->nnz; i++) + result->x[svec->indices[i] - 1] = Float4ToHalf(values[i]); + + PG_RETURN_POINTER(result); +} diff --git a/src/sparsevec.c b/src/sparsevec.c index 22f7602..c38aff6 100644 --- a/src/sparsevec.c +++ b/src/sparsevec.c @@ -4,6 +4,8 @@ #include #include "fmgr.h" +#include "halfutils.h" +#include "halfvec.h" #include "libpq/pqformat.h" #include "sparsevec.h" #include "utils/array.h" @@ -612,6 +614,49 @@ vector_to_sparsevec(PG_FUNCTION_ARGS) PG_RETURN_POINTER(result); } +/* + * Convert half vector to sparse vector + */ +PGDLLEXPORT PG_FUNCTION_INFO_V1(halfvec_to_sparsevec); +Datum +halfvec_to_sparsevec(PG_FUNCTION_ARGS) +{ + HalfVector *vec = PG_GETARG_HALFVEC_P(0); + int32 typmod = PG_GETARG_INT32(1); + SparseVector *result; + int dim = vec->dim; + int nnz = 0; + float *values; + int j = 0; + + CheckDim(dim); + CheckExpectedDim(typmod, dim); + + for (int i = 0; i < dim; i++) + { + if (!HalfIsZero(vec->x[i])) + nnz++; + } + + result = InitSparseVector(dim, nnz); + values = SPARSEVEC_VALUES(result); + for (int i = 0; i < dim; i++) + { + if (!HalfIsZero(vec->x[i])) + { + /* Safety check */ + if (j >= result->nnz) + elog(ERROR, "safety check failed"); + + result->indices[j] = i + 1; + values[j] = HalfToFloat4(vec->x[i]); + j++; + } + } + + PG_RETURN_POINTER(result); +} + /* * Get the L2 squared distance between sparse vectors */ diff --git a/test/expected/cast.out b/test/expected/cast.out index 2b36c3d..ff34c8a 100644 --- a/test/expected/cast.out +++ b/test/expected/cast.out @@ -170,6 +170,36 @@ SELECT '{2:1.5,4:3.5}/5'::sparsevec::vector(4); ERROR: expected 4 dimensions, not 5 SELECT '{}/16001'::sparsevec::vector; ERROR: vector cannot have more than 16000 dimensions +SELECT '[0,1.5,0,3.5,0]'::halfvec::sparsevec; + sparsevec +----------------- + {2:1.5,4:3.5}/5 +(1 row) + +SELECT '[0,1.5,0,3.5,0]'::halfvec::sparsevec(5); + sparsevec +----------------- + {2:1.5,4:3.5}/5 +(1 row) + +SELECT '[0,1.5,0,3.5,0]'::halfvec::sparsevec(4); +ERROR: expected 4 dimensions, not 5 +SELECT '{2:1.5,4:3.5}/5'::sparsevec::halfvec; + halfvec +----------------- + [0,1.5,0,3.5,0] +(1 row) + +SELECT '{2:1.5,4:3.5}/5'::sparsevec::halfvec(5); + halfvec +----------------- + [0,1.5,0,3.5,0] +(1 row) + +SELECT '{2:1.5,4:3.5}/5'::sparsevec::halfvec(4); +ERROR: expected 4 dimensions, not 5 +SELECT '{}/16001'::sparsevec::halfvec; +ERROR: halfvec cannot have more than 16000 dimensions SELECT array_agg(n)::vector FROM generate_series(1, 16001) n; ERROR: vector cannot have more than 16000 dimensions SELECT array_to_vector(array_agg(n), 16001, false) FROM generate_series(1, 16001) n; diff --git a/test/sql/cast.sql b/test/sql/cast.sql index 06399ea..c5d1d3e 100644 --- a/test/sql/cast.sql +++ b/test/sql/cast.sql @@ -47,6 +47,15 @@ SELECT '{2:1.5,4:3.5}/5'::sparsevec::vector(5); SELECT '{2:1.5,4:3.5}/5'::sparsevec::vector(4); SELECT '{}/16001'::sparsevec::vector; +SELECT '[0,1.5,0,3.5,0]'::halfvec::sparsevec; +SELECT '[0,1.5,0,3.5,0]'::halfvec::sparsevec(5); +SELECT '[0,1.5,0,3.5,0]'::halfvec::sparsevec(4); + +SELECT '{2:1.5,4:3.5}/5'::sparsevec::halfvec; +SELECT '{2:1.5,4:3.5}/5'::sparsevec::halfvec(5); +SELECT '{2:1.5,4:3.5}/5'::sparsevec::halfvec(4); +SELECT '{}/16001'::sparsevec::halfvec; + SELECT array_agg(n)::vector FROM generate_series(1, 16001) n; SELECT array_to_vector(array_agg(n), 16001, false) FROM generate_series(1, 16001) n;