Added quantize_binary and subvector functions for halfvec

This commit is contained in:
Andrew Kane
2024-04-03 14:53:03 -07:00
parent 253acbccf4
commit aaa2d644ce
8 changed files with 136 additions and 18 deletions

View File

@@ -778,6 +778,8 @@ cosine_distance(halfvec, halfvec) → double precision | cosine distance | unrel
inner_product(halfvec, halfvec) → double precision | inner product | unreleased
l2_distance(halfvec, halfvec) → double precision | Euclidean distance | unreleased
l1_distance(halfvec, halfvec) → double precision | taxicab distance | unreleased
quantize_binary(halfvec) → bit | quantize | unreleased
subvector(halfvec, integer, integer) → halfvec | subvector | unreleased
### Bit Type

View File

@@ -74,6 +74,12 @@ CREATE FUNCTION l1_distance(halfvec, halfvec) RETURNS float8
CREATE FUNCTION halfvec_norm(halfvec) RETURNS float8
AS 'MODULE_PATHNAME' LANGUAGE C IMMUTABLE STRICT PARALLEL SAFE;
CREATE FUNCTION quantize_binary(halfvec) RETURNS bit
AS 'MODULE_PATHNAME', 'halfvec_quantize_binary' LANGUAGE C IMMUTABLE STRICT PARALLEL SAFE;
CREATE FUNCTION subvector(halfvec, int, int) RETURNS halfvec
AS 'MODULE_PATHNAME', 'halfvec_subvector' LANGUAGE C IMMUTABLE STRICT PARALLEL SAFE;
CREATE FUNCTION halfvec_l2_squared_distance(halfvec, halfvec) RETURNS float8
AS 'MODULE_PATHNAME' LANGUAGE C IMMUTABLE STRICT PARALLEL SAFE;

View File

@@ -367,6 +367,12 @@ CREATE FUNCTION l1_distance(halfvec, halfvec) RETURNS float8
CREATE FUNCTION halfvec_norm(halfvec) RETURNS float8
AS 'MODULE_PATHNAME' LANGUAGE C IMMUTABLE STRICT PARALLEL SAFE;
CREATE FUNCTION quantize_binary(halfvec) RETURNS bit
AS 'MODULE_PATHNAME', 'halfvec_quantize_binary' LANGUAGE C IMMUTABLE STRICT PARALLEL SAFE;
CREATE FUNCTION subvector(halfvec, int, int) RETURNS halfvec
AS 'MODULE_PATHNAME', 'halfvec_subvector' LANGUAGE C IMMUTABLE STRICT PARALLEL SAFE;
-- halfvec private functions
CREATE FUNCTION halfvec_l2_squared_distance(halfvec, halfvec) RETURNS float8

View File

@@ -2,6 +2,7 @@
#include <math.h>
#include "bitvector.h"
#include "catalog/pg_type.h"
#include "common/shortest_dec.h"
#include "fmgr.h"
@@ -967,3 +968,53 @@ halfvec_norm(PG_FUNCTION_ARGS)
PG_RETURN_FLOAT8(sqrt(norm));
}
/*
* Quantize a half vector
*/
PGDLLEXPORT PG_FUNCTION_INFO_V1(halfvec_quantize_binary);
Datum
halfvec_quantize_binary(PG_FUNCTION_ARGS)
{
HalfVector *a = PG_GETARG_HALFVEC_P(0);
half *ax = a->x;
VarBit *result = InitBitVector(a->dim);
unsigned char *rx = VARBITS(result);
for (int i = 0; i < a->dim; i++)
rx[i / 8] |= (HalfToFloat4(ax[i]) > 0) << (7 - (i % 8));
PG_RETURN_VARBIT_P(result);
}
/*
* Get a subvector
*/
PGDLLEXPORT PG_FUNCTION_INFO_V1(halfvec_subvector);
Datum
halfvec_subvector(PG_FUNCTION_ARGS)
{
HalfVector *a = PG_GETARG_HALFVEC_P(0);
int32 start = PG_GETARG_INT32(1);
int32 count = PG_GETARG_INT32(2);
int32 end = start + count;
half *ax = a->x;
HalfVector *result;
int dim;
/* Indexing starts at 1, like substring */
if (start < 1)
start = 1;
if (end > a->dim)
end = a->dim + 1;
dim = end - start;
CheckDim(dim);
result = InitHalfVector(dim);
for (int i = 0; i < dim; i++)
result->x[i] = ax[start - 1 + i];
PG_RETURN_POINTER(result);
}

View File

@@ -102,3 +102,45 @@ SELECT l1_distance('[0,0]'::halfvec, '[0,1]');
SELECT l1_distance('[1,2]'::halfvec, '[3]');
ERROR: different halfvec dimensions 2 and 1
SELECT quantize_binary('[1,0,-1]'::halfvec);
quantize_binary
-----------------
100
(1 row)
SELECT quantize_binary('[0,0.1,-0.2,-0.3,0.4,0.5,0.6,-0.7,0.8,-0.9,1]'::halfvec);
quantize_binary
-----------------
01001110101
(1 row)
SELECT subvector('[1,2,3,4,5]'::halfvec, 1, 3);
subvector
-----------
[1,2,3]
(1 row)
SELECT subvector('[1,2,3,4,5]'::halfvec, 3, 2);
subvector
-----------
[3,4]
(1 row)
SELECT subvector('[1,2,3,4,5]'::halfvec, -1, 3);
subvector
-----------
[1]
(1 row)
SELECT subvector('[1,2,3,4,5]'::halfvec, 3, 9);
subvector
-----------
[3,4,5]
(1 row)
SELECT subvector('[1,2,3,4,5]'::halfvec, 1, 0);
ERROR: halfvec must have at least 1 dimension
SELECT subvector('[1,2,3,4,5]'::halfvec, 3, -1);
ERROR: halfvec must have at least 1 dimension
SELECT subvector('[1,2,3,4,5]'::halfvec, -1, 2);
ERROR: halfvec must have at least 1 dimension

View File

@@ -272,47 +272,47 @@ SELECT l1_distance('[3e38]'::vector, '[-3e38]');
Infinity
(1 row)
SELECT quantize_binary('[1,0,-1]');
SELECT quantize_binary('[1,0,-1]'::vector);
quantize_binary
-----------------
100
(1 row)
SELECT quantize_binary('[0,0.1,-0.2,-0.3,0.4,0.5,0.6,-0.7,0.8,-0.9,1]');
SELECT quantize_binary('[0,0.1,-0.2,-0.3,0.4,0.5,0.6,-0.7,0.8,-0.9,1]'::vector);
quantize_binary
-----------------
01001110101
(1 row)
SELECT subvector('[1,2,3,4,5]', 1, 3);
SELECT subvector('[1,2,3,4,5]'::vector, 1, 3);
subvector
-----------
[1,2,3]
(1 row)
SELECT subvector('[1,2,3,4,5]', 3, 2);
SELECT subvector('[1,2,3,4,5]'::vector, 3, 2);
subvector
-----------
[3,4]
(1 row)
SELECT subvector('[1,2,3,4,5]', -1, 3);
SELECT subvector('[1,2,3,4,5]'::vector, -1, 3);
subvector
-----------
[1]
(1 row)
SELECT subvector('[1,2,3,4,5]', 3, 9);
SELECT subvector('[1,2,3,4,5]'::vector, 3, 9);
subvector
-----------
[3,4,5]
(1 row)
SELECT subvector('[1,2,3,4,5]', 1, 0);
SELECT subvector('[1,2,3,4,5]'::vector, 1, 0);
ERROR: vector must have at least 1 dimension
SELECT subvector('[1,2,3,4,5]', 3, -1);
SELECT subvector('[1,2,3,4,5]'::vector, 3, -1);
ERROR: vector must have at least 1 dimension
SELECT subvector('[1,2,3,4,5]', -1, 2);
SELECT subvector('[1,2,3,4,5]'::vector, -1, 2);
ERROR: vector must have at least 1 dimension
SELECT avg(v) FROM unnest(ARRAY['[1,2,3]'::vector, '[3,5,7]']) v;
avg

View File

@@ -21,3 +21,14 @@ SELECT '[1,2]'::halfvec <=> '[2,4]';
SELECT l1_distance('[0,0]'::halfvec, '[3,4]');
SELECT l1_distance('[0,0]'::halfvec, '[0,1]');
SELECT l1_distance('[1,2]'::halfvec, '[3]');
SELECT quantize_binary('[1,0,-1]'::halfvec);
SELECT quantize_binary('[0,0.1,-0.2,-0.3,0.4,0.5,0.6,-0.7,0.8,-0.9,1]'::halfvec);
SELECT subvector('[1,2,3,4,5]'::halfvec, 1, 3);
SELECT subvector('[1,2,3,4,5]'::halfvec, 3, 2);
SELECT subvector('[1,2,3,4,5]'::halfvec, -1, 3);
SELECT subvector('[1,2,3,4,5]'::halfvec, 3, 9);
SELECT subvector('[1,2,3,4,5]'::halfvec, 1, 0);
SELECT subvector('[1,2,3,4,5]'::halfvec, 3, -1);
SELECT subvector('[1,2,3,4,5]'::halfvec, -1, 2);

View File

@@ -58,16 +58,16 @@ SELECT l1_distance('[0,0]'::vector, '[0,1]');
SELECT l1_distance('[1,2]'::vector, '[3]');
SELECT l1_distance('[3e38]'::vector, '[-3e38]');
SELECT quantize_binary('[1,0,-1]');
SELECT quantize_binary('[0,0.1,-0.2,-0.3,0.4,0.5,0.6,-0.7,0.8,-0.9,1]');
SELECT quantize_binary('[1,0,-1]'::vector);
SELECT quantize_binary('[0,0.1,-0.2,-0.3,0.4,0.5,0.6,-0.7,0.8,-0.9,1]'::vector);
SELECT subvector('[1,2,3,4,5]', 1, 3);
SELECT subvector('[1,2,3,4,5]', 3, 2);
SELECT subvector('[1,2,3,4,5]', -1, 3);
SELECT subvector('[1,2,3,4,5]', 3, 9);
SELECT subvector('[1,2,3,4,5]', 1, 0);
SELECT subvector('[1,2,3,4,5]', 3, -1);
SELECT subvector('[1,2,3,4,5]', -1, 2);
SELECT subvector('[1,2,3,4,5]'::vector, 1, 3);
SELECT subvector('[1,2,3,4,5]'::vector, 3, 2);
SELECT subvector('[1,2,3,4,5]'::vector, -1, 3);
SELECT subvector('[1,2,3,4,5]'::vector, 3, 9);
SELECT subvector('[1,2,3,4,5]'::vector, 1, 0);
SELECT subvector('[1,2,3,4,5]'::vector, 3, -1);
SELECT subvector('[1,2,3,4,5]'::vector, -1, 2);
SELECT avg(v) FROM unnest(ARRAY['[1,2,3]'::vector, '[3,5,7]']) v;
SELECT avg(v) FROM unnest(ARRAY['[1,2,3]'::vector, '[3,5,7]', NULL]) v;