Add quantize_binary and subvector functions for the halfvec type.

halfvec_subvector computes the end position in 64 bits so that large
start/count arguments cannot overflow int32 and index past the input.
Note: the "index" lines below carry the blob ids of the original change and
do not reflect the revised src/halfvec.c and halfvec test files.

diff --git a/README.md b/README.md
index 5bf5cf6..ad20db6 100644
--- a/README.md
+++ b/README.md
@@ -778,6 +778,8 @@ cosine_distance(halfvec, halfvec) → double precision | cosine distance | unrel
 inner_product(halfvec, halfvec) → double precision | inner product | unreleased
 l2_distance(halfvec, halfvec) → double precision | Euclidean distance | unreleased
 l1_distance(halfvec, halfvec) → double precision | taxicab distance | unreleased
+quantize_binary(halfvec) → bit | quantize | unreleased
+subvector(halfvec, integer, integer) → halfvec | subvector | unreleased
 
 ### Bit Type
 
diff --git a/sql/vector--0.6.2--0.7.0.sql b/sql/vector--0.6.2--0.7.0.sql
index f767d6a..3fdc8f1 100644
--- a/sql/vector--0.6.2--0.7.0.sql
+++ b/sql/vector--0.6.2--0.7.0.sql
@@ -74,6 +74,12 @@ CREATE FUNCTION l1_distance(halfvec, halfvec) RETURNS float8
 CREATE FUNCTION halfvec_norm(halfvec) RETURNS float8
 	AS 'MODULE_PATHNAME' LANGUAGE C IMMUTABLE STRICT PARALLEL SAFE;
 
+CREATE FUNCTION quantize_binary(halfvec) RETURNS bit
+	AS 'MODULE_PATHNAME', 'halfvec_quantize_binary' LANGUAGE C IMMUTABLE STRICT PARALLEL SAFE;
+
+CREATE FUNCTION subvector(halfvec, int, int) RETURNS halfvec
+	AS 'MODULE_PATHNAME', 'halfvec_subvector' LANGUAGE C IMMUTABLE STRICT PARALLEL SAFE;
+
 CREATE FUNCTION halfvec_l2_squared_distance(halfvec, halfvec) RETURNS float8
 	AS 'MODULE_PATHNAME' LANGUAGE C IMMUTABLE STRICT PARALLEL SAFE;
 
diff --git a/sql/vector.sql b/sql/vector.sql
index f21b100..c694b1c 100644
--- a/sql/vector.sql
+++ b/sql/vector.sql
@@ -367,6 +367,12 @@ CREATE FUNCTION l1_distance(halfvec, halfvec) RETURNS float8
 CREATE FUNCTION halfvec_norm(halfvec) RETURNS float8
 	AS 'MODULE_PATHNAME' LANGUAGE C IMMUTABLE STRICT PARALLEL SAFE;
 
+CREATE FUNCTION quantize_binary(halfvec) RETURNS bit
+	AS 'MODULE_PATHNAME', 'halfvec_quantize_binary' LANGUAGE C IMMUTABLE STRICT PARALLEL SAFE;
+
+CREATE FUNCTION subvector(halfvec, int, int) RETURNS halfvec
+	AS 'MODULE_PATHNAME', 'halfvec_subvector' LANGUAGE C IMMUTABLE STRICT PARALLEL SAFE;
+
 -- halfvec private functions
 
 CREATE FUNCTION halfvec_l2_squared_distance(halfvec, halfvec) RETURNS float8
diff --git a/src/halfvec.c b/src/halfvec.c
index 36c1a85..d5fbf53 100644
--- a/src/halfvec.c
+++ b/src/halfvec.c
@@ -2,6 +2,7 @@
 
 #include <math.h>
 
+#include "bitvector.h"
 #include "catalog/pg_type.h"
 #include "common/shortest_dec.h"
 #include "fmgr.h"
@@ -967,3 +968,73 @@ halfvec_norm(PG_FUNCTION_ARGS)
 
 	PG_RETURN_FLOAT8(sqrt(norm));
 }
+
+/*
+ * Quantize a half vector
+ *
+ * Produces one bit per dimension, set when the element is greater than zero.
+ * Bits are packed most significant bit first within each byte.
+ */
+PGDLLEXPORT PG_FUNCTION_INFO_V1(halfvec_quantize_binary);
+Datum
+halfvec_quantize_binary(PG_FUNCTION_ARGS)
+{
+	HalfVector *a = PG_GETARG_HALFVEC_P(0);
+	half	   *ax = a->x;
+	VarBit	   *result = InitBitVector(a->dim);
+	unsigned char *rx = VARBITS(result);
+
+	for (int i = 0; i < a->dim; i++)
+		rx[i / 8] |= (HalfToFloat4(ax[i]) > 0) << (7 - (i % 8));
+
+	PG_RETURN_VARBIT_P(result);
+}
+
+/*
+ * Get a subvector
+ *
+ * Returns "count" elements of the input beginning at 1-based position
+ * "start". A start before the first element and an end past the last
+ * element are clamped to the bounds of the input. If no elements remain,
+ * the resulting dimension is rejected by CheckDim.
+ */
+PGDLLEXPORT PG_FUNCTION_INFO_V1(halfvec_subvector);
+Datum
+halfvec_subvector(PG_FUNCTION_ARGS)
+{
+	HalfVector *a = PG_GETARG_HALFVEC_P(0);
+	int32		start = PG_GETARG_INT32(1);
+	int32		count = PG_GETARG_INT32(2);
+	int64		end;
+	half	   *ax = a->x;
+	HalfVector *result;
+	int			dim;
+
+	/*
+	 * Compute the exclusive end in 64 bits. start + count can exceed the
+	 * int32 range, and a wrapped end would slip past the clamp below and
+	 * index outside the input.
+	 */
+	end = (int64) start + (int64) count;
+
+	/* Indexing starts at 1, like substring */
+	if (start < 1)
+		start = 1;
+
+	if (end > a->dim)
+		end = (int64) a->dim + 1;
+
+	/*
+	 * After clamping, end - start is at most a->dim, so a positive length
+	 * fits in an int. Collapse empty and inverted ranges to zero so the
+	 * narrowing cast cannot wrap; CheckDim then rejects them.
+	 */
+	dim = end > start ? (int) (end - start) : 0;
+	CheckDim(dim);
+	result = InitHalfVector(dim);
+
+	for (int i = 0; i < dim; i++)
+		result->x[i] = ax[start - 1 + i];
+
+	PG_RETURN_POINTER(result);
+}
diff --git a/test/expected/halfvec_functions.out b/test/expected/halfvec_functions.out
index b0bd832..1241803 100644
--- a/test/expected/halfvec_functions.out
+++ b/test/expected/halfvec_functions.out
@@ -102,3 +102,53 @@ SELECT l1_distance('[0,0]'::halfvec, '[0,1]');
 
 SELECT l1_distance('[1,2]'::halfvec, '[3]');
 ERROR:  different halfvec dimensions 2 and 1
+SELECT quantize_binary('[1,0,-1]'::halfvec);
+ quantize_binary 
+-----------------
+ 100
+(1 row)
+
+SELECT quantize_binary('[0,0.1,-0.2,-0.3,0.4,0.5,0.6,-0.7,0.8,-0.9,1]'::halfvec);
+ quantize_binary 
+-----------------
+ 01001110101
+(1 row)
+
+SELECT subvector('[1,2,3,4,5]'::halfvec, 1, 3);
+ subvector 
+-----------
+ [1,2,3]
+(1 row)
+
+SELECT subvector('[1,2,3,4,5]'::halfvec, 3, 2);
+ subvector 
+-----------
+ [3,4]
+(1 row)
+
+SELECT subvector('[1,2,3,4,5]'::halfvec, -1, 3);
+ subvector 
+-----------
+ [1]
+(1 row)
+
+SELECT subvector('[1,2,3,4,5]'::halfvec, 3, 9);
+ subvector 
+-----------
+ [3,4,5]
+(1 row)
+
+SELECT subvector('[1,2,3,4,5]'::halfvec, 3, 2147483647);
+ subvector 
+-----------
+ [3,4,5]
+(1 row)
+
+SELECT subvector('[1,2,3,4,5]'::halfvec, 1, 0);
+ERROR:  halfvec must have at least 1 dimension
+SELECT subvector('[1,2,3,4,5]'::halfvec, 3, -1);
+ERROR:  halfvec must have at least 1 dimension
+SELECT subvector('[1,2,3,4,5]'::halfvec, -1, 2);
+ERROR:  halfvec must have at least 1 dimension
+SELECT subvector('[1,2,3,4,5]'::halfvec, 2147483647, 10);
+ERROR:  halfvec must have at least 1 dimension
diff --git a/test/expected/vector_functions.out b/test/expected/vector_functions.out
index 8be1525..b489ad0 100644
--- a/test/expected/vector_functions.out
+++ b/test/expected/vector_functions.out
@@ -272,47 +272,47 @@ SELECT l1_distance('[3e38]'::vector, '[-3e38]');
     Infinity
 (1 row)
 
-SELECT quantize_binary('[1,0,-1]');
+SELECT quantize_binary('[1,0,-1]'::vector);
  quantize_binary 
 -----------------
  100
 (1 row)
 
-SELECT quantize_binary('[0,0.1,-0.2,-0.3,0.4,0.5,0.6,-0.7,0.8,-0.9,1]');
+SELECT quantize_binary('[0,0.1,-0.2,-0.3,0.4,0.5,0.6,-0.7,0.8,-0.9,1]'::vector);
  quantize_binary 
 -----------------
  01001110101
 (1 row)
 
-SELECT subvector('[1,2,3,4,5]', 1, 3);
+SELECT subvector('[1,2,3,4,5]'::vector, 1, 3);
  subvector 
 -----------
  [1,2,3]
 (1 row)
 
-SELECT subvector('[1,2,3,4,5]', 3, 2);
+SELECT subvector('[1,2,3,4,5]'::vector, 3, 2);
  subvector 
 -----------
  [3,4]
 (1 row)
 
-SELECT subvector('[1,2,3,4,5]', -1, 3);
+SELECT subvector('[1,2,3,4,5]'::vector, -1, 3);
  subvector 
 -----------
  [1]
 (1 row)
 
-SELECT subvector('[1,2,3,4,5]', 3, 9);
+SELECT subvector('[1,2,3,4,5]'::vector, 3, 9);
  subvector 
 -----------
  [3,4,5]
 (1 row)
 
-SELECT subvector('[1,2,3,4,5]', 1, 0);
+SELECT subvector('[1,2,3,4,5]'::vector, 1, 0);
 ERROR:  vector must have at least 1 dimension
-SELECT subvector('[1,2,3,4,5]', 3, -1);
+SELECT subvector('[1,2,3,4,5]'::vector, 3, -1);
 ERROR:  vector must have at least 1 dimension
-SELECT subvector('[1,2,3,4,5]', -1, 2);
+SELECT subvector('[1,2,3,4,5]'::vector, -1, 2);
 ERROR:  vector must have at least 1 dimension
 SELECT avg(v) FROM unnest(ARRAY['[1,2,3]'::vector, '[3,5,7]']) v;
     avg    
diff --git a/test/sql/halfvec_functions.sql b/test/sql/halfvec_functions.sql
index 17f28a5..a465506 100644
--- a/test/sql/halfvec_functions.sql
+++ b/test/sql/halfvec_functions.sql
@@ -21,3 +21,16 @@ SELECT '[1,2]'::halfvec <=> '[2,4]';
 SELECT l1_distance('[0,0]'::halfvec, '[3,4]');
 SELECT l1_distance('[0,0]'::halfvec, '[0,1]');
 SELECT l1_distance('[1,2]'::halfvec, '[3]');
+
+SELECT quantize_binary('[1,0,-1]'::halfvec);
+SELECT quantize_binary('[0,0.1,-0.2,-0.3,0.4,0.5,0.6,-0.7,0.8,-0.9,1]'::halfvec);
+
+SELECT subvector('[1,2,3,4,5]'::halfvec, 1, 3);
+SELECT subvector('[1,2,3,4,5]'::halfvec, 3, 2);
+SELECT subvector('[1,2,3,4,5]'::halfvec, -1, 3);
+SELECT subvector('[1,2,3,4,5]'::halfvec, 3, 9);
+SELECT subvector('[1,2,3,4,5]'::halfvec, 3, 2147483647);
+SELECT subvector('[1,2,3,4,5]'::halfvec, 1, 0);
+SELECT subvector('[1,2,3,4,5]'::halfvec, 3, -1);
+SELECT subvector('[1,2,3,4,5]'::halfvec, -1, 2);
+SELECT subvector('[1,2,3,4,5]'::halfvec, 2147483647, 10);
diff --git a/test/sql/vector_functions.sql b/test/sql/vector_functions.sql
index 8e8cc6e..62b1a69 100644
--- a/test/sql/vector_functions.sql
+++ b/test/sql/vector_functions.sql
@@ -58,16 +58,16 @@ SELECT l1_distance('[0,0]'::vector, '[0,1]');
 SELECT l1_distance('[1,2]'::vector, '[3]');
 SELECT l1_distance('[3e38]'::vector, '[-3e38]');
 
-SELECT quantize_binary('[1,0,-1]');
-SELECT quantize_binary('[0,0.1,-0.2,-0.3,0.4,0.5,0.6,-0.7,0.8,-0.9,1]');
+SELECT quantize_binary('[1,0,-1]'::vector);
+SELECT quantize_binary('[0,0.1,-0.2,-0.3,0.4,0.5,0.6,-0.7,0.8,-0.9,1]'::vector);
 
-SELECT subvector('[1,2,3,4,5]', 1, 3);
-SELECT subvector('[1,2,3,4,5]', 3, 2);
-SELECT subvector('[1,2,3,4,5]', -1, 3);
-SELECT subvector('[1,2,3,4,5]', 3, 9);
-SELECT subvector('[1,2,3,4,5]', 1, 0);
-SELECT subvector('[1,2,3,4,5]', 3, -1);
-SELECT subvector('[1,2,3,4,5]', -1, 2);
+SELECT subvector('[1,2,3,4,5]'::vector, 1, 3);
+SELECT subvector('[1,2,3,4,5]'::vector, 3, 2);
+SELECT subvector('[1,2,3,4,5]'::vector, -1, 3);
+SELECT subvector('[1,2,3,4,5]'::vector, 3, 9);
+SELECT subvector('[1,2,3,4,5]'::vector, 1, 0);
+SELECT subvector('[1,2,3,4,5]'::vector, 3, -1);
+SELECT subvector('[1,2,3,4,5]'::vector, -1, 2);
 
 SELECT avg(v) FROM unnest(ARRAY['[1,2,3]'::vector, '[3,5,7]']) v;
 SELECT avg(v) FROM unnest(ARRAY['[1,2,3]'::vector, '[3,5,7]', NULL]) v;