diff --git a/CHANGELOG.md b/CHANGELOG.md index 74bcbe1..2371c7f 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -1,6 +1,7 @@ ## 0.7.0 (unreleased) - Added `hamming_distance` function +- Added `quantize_binary` function ## 0.6.2 (2024-03-18) diff --git a/README.md b/README.md index 0683987..73f7b91 100644 --- a/README.md +++ b/README.md @@ -706,6 +706,7 @@ cosine_distance(vector, vector) → double precision | cosine distance | inner_product(vector, vector) → double precision | inner product | l2_distance(vector, vector) → double precision | Euclidean distance | l1_distance(vector, vector) → double precision | taxicab distance | 0.5.0 +quantize_binary(vector) → bit | quantize | 0.7.0 vector_dims(vector) → integer | number of dimensions | vector_norm(vector) → double precision | Euclidean norm | diff --git a/sql/vector--0.6.2--0.7.0.sql b/sql/vector--0.6.2--0.7.0.sql index 321121a..d385eaf 100644 --- a/sql/vector--0.6.2--0.7.0.sql +++ b/sql/vector--0.6.2--0.7.0.sql @@ -1,6 +1,9 @@ -- complain if script is sourced in psql, rather than via CREATE EXTENSION \echo Use "ALTER EXTENSION vector UPDATE TO '0.7.0'" to load this file. \quit +CREATE FUNCTION quantize_binary(vector) RETURNS bit + AS 'MODULE_PATHNAME' LANGUAGE C IMMUTABLE STRICT PARALLEL SAFE; + CREATE FUNCTION hamming_distance(bit, bit) RETURNS float8 AS 'MODULE_PATHNAME' LANGUAGE C IMMUTABLE STRICT PARALLEL SAFE; diff --git a/sql/vector.sql b/sql/vector.sql index 3cbaadb..f37e93e 100644 --- a/sql/vector.sql +++ b/sql/vector.sql @@ -58,6 +58,9 @@ CREATE FUNCTION vector_sub(vector, vector) RETURNS vector CREATE FUNCTION vector_mul(vector, vector) RETURNS vector AS 'MODULE_PATHNAME' LANGUAGE C IMMUTABLE STRICT PARALLEL SAFE; +CREATE FUNCTION quantize_binary(vector) RETURNS bit + AS 'MODULE_PATHNAME' LANGUAGE C IMMUTABLE STRICT PARALLEL SAFE; + -- private functions CREATE FUNCTION vector_lt(vector, vector) RETURNS bool diff --git a/src/hnswscan.c b/src/hnswscan.c index 969c04a..e659c14 100644 --- a/src/hnswscan.c +++ b/src/hnswscan.c @@ -80,14 +80,7 @@ GetScanValue(IndexScanDesc scan) int dimensions = GetDimensions(scan->indexRelation); if (typid == BITOID || typid == VARBITOID) - { - int len = VARBITTOTALLEN(dimensions); - VarBit *v = (VarBit *) palloc0(len); - - SET_VARSIZE(v, len); - VARBITLEN(v) = dimensions; - value = PointerGetDatum(v); - } + value = PointerGetDatum(InitBitVector(dimensions)); else value = PointerGetDatum(InitVector(dimensions)); } diff --git a/src/vector.c b/src/vector.c index 15711f8..7be4c8f 100644 --- a/src/vector.c +++ b/src/vector.c @@ -1163,6 +1163,42 @@ vector_avg(PG_FUNCTION_ARGS) PG_RETURN_POINTER(result); } +/* + * Allocate and initialize a new bit vector + */ +VarBit * +InitBitVector(int dim) +{ + VarBit *result; + int size; + + size = VARBITTOTALLEN(dim); + result = (VarBit *) palloc0(size); + SET_VARSIZE(result, size); + VARBITLEN(result) = dim; + + return result; +} + +/* + * Quantize a vector + */ +PGDLLEXPORT PG_FUNCTION_INFO_V1(quantize_binary); +Datum +quantize_binary(PG_FUNCTION_ARGS) +{ + Vector *a = PG_GETARG_VECTOR_P(0); + float *ax = a->x; + VarBit *result = InitBitVector(a->dim); + unsigned char *rx = VARBITS(result); + + /* TODO Improve */ + for (int i = 0; i < a->dim; i++) + rx[i / 8] |= (ax[i] > 0) << (7 - (i % 8)); + + PG_RETURN_VARBIT_P(result); +} + /* * Ensure same number of bits */ diff --git a/src/vector.h b/src/vector.h index e649471..d50c00f 100644 --- a/src/vector.h +++ b/src/vector.h @@ -1,6 +1,8 @@ #ifndef VECTOR_H #define VECTOR_H +#include "utils/varbit.h" + #define VECTOR_MAX_DIM 16000 #define VECTOR_SIZE(_dim) (offsetof(Vector, x) + sizeof(float)*(_dim)) @@ -17,6 +19,7 @@ typedef struct Vector } Vector; Vector *InitVector(int dim); +VarBit *InitBitVector(int dim); void PrintVector(char *msg, Vector * vector); int vector_cmp_internal(Vector * a, Vector * b); diff --git a/test/expected/functions.out b/test/expected/functions.out index 2b5e740..265a84d 100644 --- a/test/expected/functions.out +++ b/test/expected/functions.out @@ -234,6 +234,12 @@ SELECT hamming_distance(B'111', B'000'); SELECT hamming_distance(B'111', B'00'); ERROR: different bit lengths 3 and 2 +SELECT quantize_binary('[1,0,-1]'); + quantize_binary +----------------- + 100 +(1 row) + SELECT avg(v) FROM unnest(ARRAY['[1,2,3]'::vector, '[3,5,7]']) v; avg ----------- diff --git a/test/sql/functions.sql b/test/sql/functions.sql index 81fe752..42c706e 100644 --- a/test/sql/functions.sql +++ b/test/sql/functions.sql @@ -54,6 +54,8 @@ SELECT hamming_distance(B'111', B'100'); SELECT hamming_distance(B'111', B'000'); SELECT hamming_distance(B'111', B'00'); +SELECT quantize_binary('[1,0,-1]'); + SELECT avg(v) FROM unnest(ARRAY['[1,2,3]'::vector, '[3,5,7]']) v; SELECT avg(v) FROM unnest(ARRAY['[1,2,3]'::vector, '[3,5,7]', NULL]) v; SELECT avg(v) FROM unnest(ARRAY[]::vector[]) v;