diff --git a/Makefile b/Makefile index e7ae85e..406e59b 100644 --- a/Makefile +++ b/Makefile @@ -4,8 +4,8 @@ EXTVERSION = 0.7.4 MODULE_big = vector DATA = $(wildcard sql/*--*--*.sql) DATA_built = sql/$(EXTENSION)--$(EXTVERSION).sql -OBJS = src/bitutils.o src/bitvec.o src/halfutils.o src/halfvec.o src/hnsw.o src/hnswbuild.o src/hnswinsert.o src/hnswscan.o src/hnswutils.o src/hnswvacuum.o src/ivfbuild.o src/ivfflat.o src/ivfinsert.o src/ivfkmeans.o src/ivfscan.o src/ivfutils.o src/ivfvacuum.o src/sparsevec.o src/vector.o -HEADERS = src/halfvec.h src/sparsevec.h src/vector.h +OBJS = src/bitutils.o src/bitvec.o src/halfutils.o src/halfvec.o src/hnsw.o src/hnswbuild.o src/hnswinsert.o src/hnswscan.o src/hnswutils.o src/hnswvacuum.o src/ivfbuild.o src/ivfflat.o src/ivfinsert.o src/ivfkmeans.o src/ivfscan.o src/ivfutils.o src/ivfvacuum.o src/minivec.o src/sparsevec.o src/vector.o +HEADERS = src/halfvec.h src/minivec.h src/sparsevec.h src/vector.h TESTS = $(wildcard test/sql/*.sql) REGRESS = $(patsubst test/sql/%.sql,%,$(TESTS)) diff --git a/Makefile.win b/Makefile.win index c44cb1f..e75287a 100644 --- a/Makefile.win +++ b/Makefile.win @@ -2,10 +2,10 @@ EXTENSION = vector EXTVERSION = 0.7.4 DATA_built = sql\$(EXTENSION)--$(EXTVERSION).sql -OBJS = src\bitutils.obj src\bitvec.obj src\halfutils.obj src\halfvec.obj src\hnsw.obj src\hnswbuild.obj src\hnswinsert.obj src\hnswscan.obj src\hnswutils.obj src\hnswvacuum.obj src\ivfbuild.obj src\ivfflat.obj src\ivfinsert.obj src\ivfkmeans.obj src\ivfscan.obj src\ivfutils.obj src\ivfvacuum.obj src\sparsevec.obj src\vector.obj -HEADERS = src\halfvec.h src\sparsevec.h src\vector.h +OBJS = src\bitutils.obj src\bitvec.obj src\halfutils.obj src\halfvec.obj src\hnsw.obj src\hnswbuild.obj src\hnswinsert.obj src\hnswscan.obj src\hnswutils.obj src\hnswvacuum.obj src\ivfbuild.obj src\ivfflat.obj src\ivfinsert.obj src\ivfkmeans.obj src\ivfscan.obj src\ivfutils.obj src\ivfvacuum.obj src\minivec.obj src\sparsevec.obj src\vector.obj +HEADERS = 
src\halfvec.h src\minivec.h src\sparsevec.h src\vector.h -REGRESS = bit btree cast copy halfvec hnsw_bit hnsw_halfvec hnsw_sparsevec hnsw_vector ivfflat_bit ivfflat_halfvec ivfflat_vector sparsevec vector_type +REGRESS = bit btree cast copy halfvec hnsw_bit hnsw_halfvec hnsw_sparsevec hnsw_vector ivfflat_bit ivfflat_halfvec ivfflat_vector minivec sparsevec vector_type REGRESS_OPTS = --inputdir=test --load-extension=$(EXTENSION) # For /arch flags diff --git a/README.md b/README.md index 7439964..dd9d260 100644 --- a/README.md +++ b/README.md @@ -934,6 +934,37 @@ Function | Description | Added avg(halfvec) → halfvec | average | 0.7.0 sum(halfvec) → halfvec | sum | 0.7.0 +### Minivec Type + +Each mini vector takes `dimensions + 8` bytes of storage. Each element is an E4M3 8-bit floating-point number, and all elements must be finite (no `NaN`). Mini vectors can have up to 16,000 dimensions. + +### Minivec Operators + +Operator | Description | Added +--- | --- | --- +\+ | element-wise addition | 0.8.0 +\- | element-wise subtraction | 0.8.0 +\* | element-wise multiplication | 0.8.0 +\|\| | concatenate | 0.8.0 +<-> | Euclidean distance | 0.8.0 +<#> | negative inner product | 0.8.0 +<=> | cosine distance | 0.8.0 +<+> | taxicab distance | 0.8.0 + +### Minivec Functions + +Function | Description | Added +--- | --- | --- +binary_quantize(minivec) → bit | binary quantize | 0.8.0 +cosine_distance(minivec, minivec) → double precision | cosine distance | 0.8.0 +inner_product(minivec, minivec) → double precision | inner product | 0.8.0 +l1_distance(minivec, minivec) → double precision | taxicab distance | 0.8.0 +l2_distance(minivec, minivec) → double precision | Euclidean distance | 0.8.0 +l2_norm(minivec) → double precision | Euclidean norm | 0.8.0 +l2_normalize(minivec) → minivec | Normalize with Euclidean norm | 0.8.0 +subvector(minivec, integer, integer) → minivec | subvector | 0.8.0 +vector_dims(minivec) → integer | number of dimensions | 0.8.0 + ### Bit Type Each bit vector 
takes `dimensions / 8 + 8` bytes of storage. See the [Postgres docs](https://www.postgresql.org/docs/current/datatype-bit.html) for more info. diff --git a/sql/vector--0.7.4--0.8.0.sql b/sql/vector--0.7.4--0.8.0.sql index e00348d..84b711e 100644 --- a/sql/vector--0.7.4--0.8.0.sql +++ b/sql/vector--0.7.4--0.8.0.sql @@ -1,6 +1,8 @@ -- complain if script is sourced in psql, rather than via CREATE EXTENSION \echo Use "ALTER EXTENSION vector UPDATE TO '0.8.0'" to load this file. \quit +-- TODO minivec functions + CREATE FUNCTION array_to_sparsevec(integer[], integer, boolean) RETURNS sparsevec AS 'MODULE_PATHNAME' LANGUAGE C IMMUTABLE STRICT PARALLEL SAFE; diff --git a/sql/vector.sql b/sql/vector.sql index 7fc3671..15f484c 100644 --- a/sql/vector.sql +++ b/sql/vector.sql @@ -272,6 +272,9 @@ CREATE FUNCTION ivfflat_bit_support(internal) RETURNS internal CREATE FUNCTION hnsw_halfvec_support(internal) RETURNS internal AS 'MODULE_PATHNAME' LANGUAGE C; +CREATE FUNCTION hnsw_minivec_support(internal) RETURNS internal + AS 'MODULE_PATHNAME' LANGUAGE C; + CREATE FUNCTION hnsw_bit_support(internal) RETURNS internal AS 'MODULE_PATHNAME' LANGUAGE C; @@ -647,6 +650,268 @@ CREATE OPERATOR CLASS halfvec_l1_ops FUNCTION 1 l1_distance(halfvec, halfvec), FUNCTION 3 hnsw_halfvec_support(internal); +-- minivec type + +CREATE TYPE minivec; + +CREATE FUNCTION minivec_in(cstring, oid, integer) RETURNS minivec + AS 'MODULE_PATHNAME' LANGUAGE C IMMUTABLE STRICT PARALLEL SAFE; + +CREATE FUNCTION minivec_out(minivec) RETURNS cstring + AS 'MODULE_PATHNAME' LANGUAGE C IMMUTABLE STRICT PARALLEL SAFE; + +CREATE FUNCTION minivec_typmod_in(cstring[]) RETURNS integer + AS 'MODULE_PATHNAME' LANGUAGE C IMMUTABLE STRICT PARALLEL SAFE; + +CREATE FUNCTION minivec_recv(internal, oid, integer) RETURNS minivec + AS 'MODULE_PATHNAME' LANGUAGE C IMMUTABLE STRICT PARALLEL SAFE; + +CREATE FUNCTION minivec_send(minivec) RETURNS bytea + AS 'MODULE_PATHNAME' LANGUAGE C IMMUTABLE STRICT PARALLEL SAFE; + +CREATE TYPE 
minivec ( + INPUT = minivec_in, + OUTPUT = minivec_out, + TYPMOD_IN = minivec_typmod_in, + RECEIVE = minivec_recv, + SEND = minivec_send, + STORAGE = external +); + +-- minivec functions + +CREATE FUNCTION l2_distance(minivec, minivec) RETURNS float8 + AS 'MODULE_PATHNAME', 'minivec_l2_distance' LANGUAGE C IMMUTABLE STRICT PARALLEL SAFE; + +CREATE FUNCTION inner_product(minivec, minivec) RETURNS float8 + AS 'MODULE_PATHNAME', 'minivec_inner_product' LANGUAGE C IMMUTABLE STRICT PARALLEL SAFE; + +CREATE FUNCTION cosine_distance(minivec, minivec) RETURNS float8 + AS 'MODULE_PATHNAME', 'minivec_cosine_distance' LANGUAGE C IMMUTABLE STRICT PARALLEL SAFE; + +CREATE FUNCTION l1_distance(minivec, minivec) RETURNS float8 + AS 'MODULE_PATHNAME', 'minivec_l1_distance' LANGUAGE C IMMUTABLE STRICT PARALLEL SAFE; + +CREATE FUNCTION vector_dims(minivec) RETURNS integer + AS 'MODULE_PATHNAME', 'minivec_vector_dims' LANGUAGE C IMMUTABLE STRICT PARALLEL SAFE; + +CREATE FUNCTION l2_norm(minivec) RETURNS float8 + AS 'MODULE_PATHNAME', 'minivec_l2_norm' LANGUAGE C IMMUTABLE STRICT PARALLEL SAFE; + +CREATE FUNCTION l2_normalize(minivec) RETURNS minivec + AS 'MODULE_PATHNAME', 'minivec_l2_normalize' LANGUAGE C IMMUTABLE STRICT PARALLEL SAFE; + +CREATE FUNCTION binary_quantize(minivec) RETURNS bit + AS 'MODULE_PATHNAME', 'minivec_binary_quantize' LANGUAGE C IMMUTABLE STRICT PARALLEL SAFE; + +CREATE FUNCTION subvector(minivec, int, int) RETURNS minivec + AS 'MODULE_PATHNAME', 'minivec_subvector' LANGUAGE C IMMUTABLE STRICT PARALLEL SAFE; + +-- minivec private functions + +CREATE FUNCTION minivec_add(minivec, minivec) RETURNS minivec + AS 'MODULE_PATHNAME' LANGUAGE C IMMUTABLE STRICT PARALLEL SAFE; + +CREATE FUNCTION minivec_sub(minivec, minivec) RETURNS minivec + AS 'MODULE_PATHNAME' LANGUAGE C IMMUTABLE STRICT PARALLEL SAFE; + +CREATE FUNCTION minivec_mul(minivec, minivec) RETURNS minivec + AS 'MODULE_PATHNAME' LANGUAGE C IMMUTABLE STRICT PARALLEL SAFE; + +CREATE FUNCTION 
minivec_concat(minivec, minivec) RETURNS minivec + AS 'MODULE_PATHNAME' LANGUAGE C IMMUTABLE STRICT PARALLEL SAFE; + +CREATE FUNCTION minivec_lt(minivec, minivec) RETURNS bool + AS 'MODULE_PATHNAME' LANGUAGE C IMMUTABLE STRICT PARALLEL SAFE; + +CREATE FUNCTION minivec_le(minivec, minivec) RETURNS bool + AS 'MODULE_PATHNAME' LANGUAGE C IMMUTABLE STRICT PARALLEL SAFE; + +CREATE FUNCTION minivec_eq(minivec, minivec) RETURNS bool + AS 'MODULE_PATHNAME' LANGUAGE C IMMUTABLE STRICT PARALLEL SAFE; + +CREATE FUNCTION minivec_ne(minivec, minivec) RETURNS bool + AS 'MODULE_PATHNAME' LANGUAGE C IMMUTABLE STRICT PARALLEL SAFE; + +CREATE FUNCTION minivec_ge(minivec, minivec) RETURNS bool + AS 'MODULE_PATHNAME' LANGUAGE C IMMUTABLE STRICT PARALLEL SAFE; + +CREATE FUNCTION minivec_gt(minivec, minivec) RETURNS bool + AS 'MODULE_PATHNAME' LANGUAGE C IMMUTABLE STRICT PARALLEL SAFE; + +CREATE FUNCTION minivec_cmp(minivec, minivec) RETURNS int4 + AS 'MODULE_PATHNAME' LANGUAGE C IMMUTABLE STRICT PARALLEL SAFE; + +CREATE FUNCTION minivec_l2_squared_distance(minivec, minivec) RETURNS float8 + AS 'MODULE_PATHNAME' LANGUAGE C IMMUTABLE STRICT PARALLEL SAFE; + +CREATE FUNCTION minivec_negative_inner_product(minivec, minivec) RETURNS float8 + AS 'MODULE_PATHNAME' LANGUAGE C IMMUTABLE STRICT PARALLEL SAFE; + +-- minivec cast functions + +CREATE FUNCTION minivec(minivec, integer, boolean) RETURNS minivec + AS 'MODULE_PATHNAME' LANGUAGE C IMMUTABLE STRICT PARALLEL SAFE; + +CREATE FUNCTION minivec_to_vector(minivec, integer, boolean) RETURNS vector + AS 'MODULE_PATHNAME' LANGUAGE C IMMUTABLE STRICT PARALLEL SAFE; + +CREATE FUNCTION vector_to_minivec(vector, integer, boolean) RETURNS minivec + AS 'MODULE_PATHNAME' LANGUAGE C IMMUTABLE STRICT PARALLEL SAFE; + +CREATE FUNCTION array_to_minivec(integer[], integer, boolean) RETURNS minivec + AS 'MODULE_PATHNAME' LANGUAGE C IMMUTABLE STRICT PARALLEL SAFE; + +CREATE FUNCTION array_to_minivec(real[], integer, boolean) RETURNS minivec + AS 
'MODULE_PATHNAME' LANGUAGE C IMMUTABLE STRICT PARALLEL SAFE; + +CREATE FUNCTION array_to_minivec(double precision[], integer, boolean) RETURNS minivec + AS 'MODULE_PATHNAME' LANGUAGE C IMMUTABLE STRICT PARALLEL SAFE; + +CREATE FUNCTION array_to_minivec(numeric[], integer, boolean) RETURNS minivec + AS 'MODULE_PATHNAME' LANGUAGE C IMMUTABLE STRICT PARALLEL SAFE; + +CREATE FUNCTION minivec_to_float4(minivec, integer, boolean) RETURNS real[] + AS 'MODULE_PATHNAME' LANGUAGE C IMMUTABLE STRICT PARALLEL SAFE; + +-- minivec casts + +CREATE CAST (minivec AS minivec) + WITH FUNCTION minivec(minivec, integer, boolean) AS IMPLICIT; + +CREATE CAST (minivec AS vector) + WITH FUNCTION minivec_to_vector(minivec, integer, boolean) AS ASSIGNMENT; + +CREATE CAST (vector AS minivec) + WITH FUNCTION vector_to_minivec(vector, integer, boolean) AS IMPLICIT; + +CREATE CAST (minivec AS real[]) + WITH FUNCTION minivec_to_float4(minivec, integer, boolean) AS ASSIGNMENT; + +CREATE CAST (integer[] AS minivec) + WITH FUNCTION array_to_minivec(integer[], integer, boolean) AS ASSIGNMENT; + +CREATE CAST (real[] AS minivec) + WITH FUNCTION array_to_minivec(real[], integer, boolean) AS ASSIGNMENT; + +CREATE CAST (double precision[] AS minivec) + WITH FUNCTION array_to_minivec(double precision[], integer, boolean) AS ASSIGNMENT; + +CREATE CAST (numeric[] AS minivec) + WITH FUNCTION array_to_minivec(numeric[], integer, boolean) AS ASSIGNMENT; + +-- minivec operators + +CREATE OPERATOR <-> ( + LEFTARG = minivec, RIGHTARG = minivec, PROCEDURE = l2_distance, + COMMUTATOR = '<->' +); + +CREATE OPERATOR <#> ( + LEFTARG = minivec, RIGHTARG = minivec, PROCEDURE = minivec_negative_inner_product, + COMMUTATOR = '<#>' +); + +CREATE OPERATOR <=> ( + LEFTARG = minivec, RIGHTARG = minivec, PROCEDURE = cosine_distance, + COMMUTATOR = '<=>' +); + +CREATE OPERATOR <+> ( + LEFTARG = minivec, RIGHTARG = minivec, PROCEDURE = l1_distance, + COMMUTATOR = '<+>' +); + +CREATE OPERATOR + ( + LEFTARG = minivec, RIGHTARG = 
minivec, PROCEDURE = minivec_add, + COMMUTATOR = + +); + +CREATE OPERATOR - ( + LEFTARG = minivec, RIGHTARG = minivec, PROCEDURE = minivec_sub +); + +CREATE OPERATOR * ( + LEFTARG = minivec, RIGHTARG = minivec, PROCEDURE = minivec_mul, + COMMUTATOR = * +); + +CREATE OPERATOR || ( + LEFTARG = minivec, RIGHTARG = minivec, PROCEDURE = minivec_concat +); + +CREATE OPERATOR < ( + LEFTARG = minivec, RIGHTARG = minivec, PROCEDURE = minivec_lt, + COMMUTATOR = > , NEGATOR = >= , + RESTRICT = scalarltsel, JOIN = scalarltjoinsel +); + +CREATE OPERATOR <= ( + LEFTARG = minivec, RIGHTARG = minivec, PROCEDURE = minivec_le, + COMMUTATOR = >= , NEGATOR = > , + RESTRICT = scalarlesel, JOIN = scalarlejoinsel +); + +CREATE OPERATOR = ( + LEFTARG = minivec, RIGHTARG = minivec, PROCEDURE = minivec_eq, + COMMUTATOR = = , NEGATOR = <> , + RESTRICT = eqsel, JOIN = eqjoinsel +); + +CREATE OPERATOR <> ( + LEFTARG = minivec, RIGHTARG = minivec, PROCEDURE = minivec_ne, + COMMUTATOR = <> , NEGATOR = = , + RESTRICT = eqsel, JOIN = eqjoinsel +); + +CREATE OPERATOR >= ( + LEFTARG = minivec, RIGHTARG = minivec, PROCEDURE = minivec_ge, + COMMUTATOR = <= , NEGATOR = < , + RESTRICT = scalargesel, JOIN = scalargejoinsel +); + +CREATE OPERATOR > ( + LEFTARG = minivec, RIGHTARG = minivec, PROCEDURE = minivec_gt, + COMMUTATOR = < , NEGATOR = <= , + RESTRICT = scalargtsel, JOIN = scalargtjoinsel +); + +-- minivec op classes + +CREATE OPERATOR CLASS minivec_ops + DEFAULT FOR TYPE minivec USING btree AS + OPERATOR 1 < , + OPERATOR 2 <= , + OPERATOR 3 = , + OPERATOR 4 >= , + OPERATOR 5 > , + FUNCTION 1 minivec_cmp(minivec, minivec); + +CREATE OPERATOR CLASS minivec_l2_ops + FOR TYPE minivec USING hnsw AS + OPERATOR 1 <-> (minivec, minivec) FOR ORDER BY float_ops, + FUNCTION 1 minivec_l2_squared_distance(minivec, minivec), + FUNCTION 3 hnsw_minivec_support(internal); + +CREATE OPERATOR CLASS minivec_ip_ops + FOR TYPE minivec USING hnsw AS + OPERATOR 1 <#> (minivec, minivec) FOR ORDER BY float_ops, + FUNCTION 
1 minivec_negative_inner_product(minivec, minivec), + FUNCTION 3 hnsw_minivec_support(internal); + +CREATE OPERATOR CLASS minivec_cosine_ops + FOR TYPE minivec USING hnsw AS + OPERATOR 1 <=> (minivec, minivec) FOR ORDER BY float_ops, + FUNCTION 1 minivec_negative_inner_product(minivec, minivec), + FUNCTION 2 l2_norm(minivec), + FUNCTION 3 hnsw_minivec_support(internal); + +CREATE OPERATOR CLASS minivec_l1_ops + FOR TYPE minivec USING hnsw AS + OPERATOR 1 <+> (minivec, minivec) FOR ORDER BY float_ops, + FUNCTION 1 l1_distance(minivec, minivec), + FUNCTION 3 hnsw_minivec_support(internal); + -- bit functions CREATE FUNCTION hamming_distance(bit, bit) RETURNS float8 diff --git a/src/hnswutils.c b/src/hnswutils.c index ac1e7de..0b465ee 100644 --- a/src/hnswutils.c +++ b/src/hnswutils.c @@ -1327,6 +1327,7 @@ HnswFindElementNeighbors(char *base, HnswElement element, HnswElement entryPoint PGDLLEXPORT Datum l2_normalize(PG_FUNCTION_ARGS); PGDLLEXPORT Datum halfvec_l2_normalize(PG_FUNCTION_ARGS); +PGDLLEXPORT Datum minivec_l2_normalize(PG_FUNCTION_ARGS); PGDLLEXPORT Datum sparsevec_l2_normalize(PG_FUNCTION_ARGS); static void @@ -1375,6 +1376,19 @@ hnsw_halfvec_support(PG_FUNCTION_ARGS) PG_RETURN_POINTER(&typeInfo); }; +FUNCTION_PREFIX PG_FUNCTION_INFO_V1(hnsw_minivec_support); +Datum +hnsw_minivec_support(PG_FUNCTION_ARGS) +{ + static const HnswTypeInfo typeInfo = { + .maxDimensions = HNSW_MAX_DIM * 4, + .normalize = minivec_l2_normalize, + .checkValue = NULL + }; + + PG_RETURN_POINTER(&typeInfo); +}; + FUNCTION_PREFIX PG_FUNCTION_INFO_V1(hnsw_bit_support); Datum hnsw_bit_support(PG_FUNCTION_ARGS) diff --git a/src/minivec.c b/src/minivec.c new file mode 100644 index 0000000..f73ab3c --- /dev/null +++ b/src/minivec.c @@ -0,0 +1,1075 @@ +#include "postgres.h" + +#include + +#include "bitvec.h" +#include "catalog/pg_type.h" +#include "common/shortest_dec.h" +#include "fmgr.h" +#include "minivec.h" +#include "lib/stringinfo.h" +#include "libpq/pqformat.h" +#include "port.h" /* 
for strtof() */ +#include "sparsevec.h" +#include "utils/array.h" +#include "utils/builtins.h" +#include "utils/float.h" +#include "utils/lsyscache.h" +#include "utils/numeric.h" +#include "vector.h" + +/* + * Ensure same dimensions + */ +static inline void +CheckDims(MiniVector * a, MiniVector * b) +{ + if (a->dim != b->dim) + ereport(ERROR, + (errcode(ERRCODE_DATA_EXCEPTION), + errmsg("different minivec dimensions %d and %d", a->dim, b->dim))); +} + +/* + * Ensure expected dimensions + */ +static inline void +CheckExpectedDim(int32 typmod, int dim) +{ + if (typmod != -1 && typmod != dim) + ereport(ERROR, + (errcode(ERRCODE_DATA_EXCEPTION), + errmsg("expected %d dimensions, not %d", typmod, dim))); +} + +/* + * Ensure valid dimensions + */ +static inline void +CheckDim(int dim) +{ + if (dim < 1) + ereport(ERROR, + (errcode(ERRCODE_DATA_EXCEPTION), + errmsg("minivec must have at least 1 dimension"))); + + if (dim > MINIVEC_MAX_DIM) + ereport(ERROR, + (errcode(ERRCODE_PROGRAM_LIMIT_EXCEEDED), + errmsg("minivec cannot have more than %d dimensions", MINIVEC_MAX_DIM))); +} + +/* + * Ensure finite element + */ +static inline void +CheckElement(fp8 value) +{ + if (Fp8IsNan(value)) + ereport(ERROR, + (errcode(ERRCODE_DATA_EXCEPTION), + errmsg("NaN not allowed in minivec"))); +} + +/* + * Allocate and initialize a new fp8 vector + */ +MiniVector * +InitMiniVector(int dim) +{ + MiniVector *result; + int size; + + size = MINIVEC_SIZE(dim); + result = (MiniVector *) palloc0(size); + SET_VARSIZE(result, size); + result->dim = dim; + + return result; +} + +/* + * Check for whitespace, since array_isspace() is static + */ +static inline bool +minivec_isspace(char ch) +{ + if (ch == ' ' || + ch == '\t' || + ch == '\n' || + ch == '\r' || + ch == '\v' || + ch == '\f') + return true; + return false; +} + +/* + * Convert textual representation to internal representation + */ +FUNCTION_PREFIX PG_FUNCTION_INFO_V1(minivec_in); +Datum +minivec_in(PG_FUNCTION_ARGS) +{ + char *lit = 
PG_GETARG_CSTRING(0); + int32 typmod = PG_GETARG_INT32(2); + fp8 x[MINIVEC_MAX_DIM]; + int dim = 0; + char *pt = lit; + MiniVector *result; + + while (minivec_isspace(*pt)) + pt++; + + if (*pt != '[') + ereport(ERROR, + (errcode(ERRCODE_INVALID_TEXT_REPRESENTATION), + errmsg("invalid input syntax for type minivec: \"%s\"", lit), + errdetail("Vector contents must start with \"[\"."))); + + pt++; + + while (minivec_isspace(*pt)) + pt++; + + if (*pt == ']') + ereport(ERROR, + (errcode(ERRCODE_DATA_EXCEPTION), + errmsg("minivec must have at least 1 dimension"))); + + for (;;) + { + float val; + char *stringEnd; + + if (dim == MINIVEC_MAX_DIM) + ereport(ERROR, + (errcode(ERRCODE_PROGRAM_LIMIT_EXCEEDED), + errmsg("minivec cannot have more than %d dimensions", MINIVEC_MAX_DIM))); + + while (minivec_isspace(*pt)) + pt++; + + /* Check for empty string like float4in */ + if (*pt == '\0') + ereport(ERROR, + (errcode(ERRCODE_INVALID_TEXT_REPRESENTATION), + errmsg("invalid input syntax for type minivec: \"%s\"", lit))); + + errno = 0; + + /* Postgres sets LC_NUMERIC to C on startup */ + val = strtof(pt, &stringEnd); + + if (stringEnd == pt) + ereport(ERROR, + (errcode(ERRCODE_INVALID_TEXT_REPRESENTATION), + errmsg("invalid input syntax for type minivec: \"%s\"", lit))); + + x[dim] = Float4ToFp8Unchecked(val); + + /* Check for range error like float4in */ + if ((errno == ERANGE && isinf(val)) || (Fp8IsNan(x[dim]) && !isnan(val))) + ereport(ERROR, + (errcode(ERRCODE_NUMERIC_VALUE_OUT_OF_RANGE), + errmsg("\"%s\" is out of range for type minivec", pnstrdup(pt, stringEnd - pt)))); + + CheckElement(x[dim]); + dim++; + + pt = stringEnd; + + while (minivec_isspace(*pt)) + pt++; + + if (*pt == ',') + pt++; + else if (*pt == ']') + { + pt++; + break; + } + else + ereport(ERROR, + (errcode(ERRCODE_INVALID_TEXT_REPRESENTATION), + errmsg("invalid input syntax for type minivec: \"%s\"", lit))); + } + + /* Only whitespace is allowed after the closing brace */ + while (minivec_isspace(*pt)) + 
pt++; + + if (*pt != '\0') + ereport(ERROR, + (errcode(ERRCODE_INVALID_TEXT_REPRESENTATION), + errmsg("invalid input syntax for type minivec: \"%s\"", lit), + errdetail("Junk after closing right brace."))); + + CheckDim(dim); + CheckExpectedDim(typmod, dim); + + result = InitMiniVector(dim); + for (int i = 0; i < dim; i++) + result->x[i] = x[i]; + + PG_RETURN_POINTER(result); +} + +#define AppendChar(ptr, c) (*(ptr)++ = (c)) +#define AppendFloat(ptr, f) ((ptr) += float_to_shortest_decimal_bufn((f), (ptr))) + +/* + * Convert internal representation to textual representation + */ +FUNCTION_PREFIX PG_FUNCTION_INFO_V1(minivec_out); +Datum +minivec_out(PG_FUNCTION_ARGS) +{ + MiniVector *vector = PG_GETARG_MINIVEC_P(0); + int dim = vector->dim; + char *buf; + char *ptr; + + /* + * Need: + * + * dim * (FLOAT_SHORTEST_DECIMAL_LEN - 1) bytes for + * float_to_shortest_decimal_bufn + * + * dim - 1 bytes for separator + * + * 3 bytes for [, ], and \0 + */ + buf = (char *) palloc(FLOAT_SHORTEST_DECIMAL_LEN * dim + 2); + ptr = buf; + + AppendChar(ptr, '['); + + for (int i = 0; i < dim; i++) + { + if (i > 0) + AppendChar(ptr, ','); + + /* + * Use shortest decimal representation of single-precision float for + * simplicity + */ + AppendFloat(ptr, Fp8ToFloat4(vector->x[i])); + } + + AppendChar(ptr, ']'); + *ptr = '\0'; + + PG_FREE_IF_COPY(vector, 0); + PG_RETURN_CSTRING(buf); +} + +/* + * Convert type modifier + */ +FUNCTION_PREFIX PG_FUNCTION_INFO_V1(minivec_typmod_in); +Datum +minivec_typmod_in(PG_FUNCTION_ARGS) +{ + ArrayType *ta = PG_GETARG_ARRAYTYPE_P(0); + int32 *tl; + int n; + + tl = ArrayGetIntegerTypmods(ta, &n); + + if (n != 1) + ereport(ERROR, + (errcode(ERRCODE_INVALID_PARAMETER_VALUE), + errmsg("invalid type modifier"))); + + if (*tl < 1) + ereport(ERROR, + (errcode(ERRCODE_INVALID_PARAMETER_VALUE), + errmsg("dimensions for type minivec must be at least 1"))); + + if (*tl > MINIVEC_MAX_DIM) + ereport(ERROR, + (errcode(ERRCODE_INVALID_PARAMETER_VALUE), + 
errmsg("dimensions for type minivec cannot exceed %d", MINIVEC_MAX_DIM))); + + PG_RETURN_INT32(*tl); +} + +/* + * Convert external binary representation to internal representation + */ +FUNCTION_PREFIX PG_FUNCTION_INFO_V1(minivec_recv); +Datum +minivec_recv(PG_FUNCTION_ARGS) +{ + StringInfo buf = (StringInfo) PG_GETARG_POINTER(0); + int32 typmod = PG_GETARG_INT32(2); + MiniVector *result; + int16 dim; + int16 unused; + + dim = pq_getmsgint(buf, sizeof(int16)); + unused = pq_getmsgint(buf, sizeof(int16)); + + CheckDim(dim); + CheckExpectedDim(typmod, dim); + + if (unused != 0) + ereport(ERROR, + (errcode(ERRCODE_DATA_EXCEPTION), + errmsg("expected unused to be 0, not %d", unused))); + + result = InitMiniVector(dim); + for (int i = 0; i < dim; i++) + { + result->x[i] = pq_getmsgint(buf, sizeof(uint8)); + CheckElement(result->x[i]); + } + + PG_RETURN_POINTER(result); +} + +/* + * Convert internal representation to the external binary representation + */ +FUNCTION_PREFIX PG_FUNCTION_INFO_V1(minivec_send); +Datum +minivec_send(PG_FUNCTION_ARGS) +{ + MiniVector *vec = PG_GETARG_MINIVEC_P(0); + StringInfoData buf; + + pq_begintypsend(&buf); + pq_sendint(&buf, vec->dim, sizeof(int16)); + pq_sendint(&buf, vec->unused, sizeof(int16)); + for (int i = 0; i < vec->dim; i++) + pq_sendint8(&buf, vec->x[i]); + + PG_RETURN_BYTEA_P(pq_endtypsend(&buf)); +} + +/* + * Convert fp8 vector to fp8 vector + * This is needed to check the type modifier + */ +FUNCTION_PREFIX PG_FUNCTION_INFO_V1(minivec); +Datum +minivec(PG_FUNCTION_ARGS) +{ + MiniVector *vec = PG_GETARG_MINIVEC_P(0); + int32 typmod = PG_GETARG_INT32(1); + + CheckExpectedDim(typmod, vec->dim); + + PG_RETURN_POINTER(vec); +} + +/* + * Convert array to fp8 vector + */ +FUNCTION_PREFIX PG_FUNCTION_INFO_V1(array_to_minivec); +Datum +array_to_minivec(PG_FUNCTION_ARGS) +{ + ArrayType *array = PG_GETARG_ARRAYTYPE_P(0); + int32 typmod = PG_GETARG_INT32(1); + MiniVector *result; + int16 typlen; + bool typbyval; + char typalign; + Datum 
*elemsp; + int nelemsp; + + if (ARR_NDIM(array) > 1) + ereport(ERROR, + (errcode(ERRCODE_DATA_EXCEPTION), + errmsg("array must be 1-D"))); + + if (ARR_HASNULL(array) && array_contains_nulls(array)) + ereport(ERROR, + (errcode(ERRCODE_NULL_VALUE_NOT_ALLOWED), + errmsg("array must not contain nulls"))); + + get_typlenbyvalalign(ARR_ELEMTYPE(array), &typlen, &typbyval, &typalign); + deconstruct_array(array, ARR_ELEMTYPE(array), typlen, typbyval, typalign, &elemsp, NULL, &nelemsp); + + CheckDim(nelemsp); + CheckExpectedDim(typmod, nelemsp); + + result = InitMiniVector(nelemsp); + + if (ARR_ELEMTYPE(array) == INT4OID) + { + for (int i = 0; i < nelemsp; i++) + result->x[i] = Float4ToFp8(DatumGetInt32(elemsp[i])); + } + else if (ARR_ELEMTYPE(array) == FLOAT8OID) + { + for (int i = 0; i < nelemsp; i++) + result->x[i] = Float4ToFp8(DatumGetFloat8(elemsp[i])); + } + else if (ARR_ELEMTYPE(array) == FLOAT4OID) + { + for (int i = 0; i < nelemsp; i++) + result->x[i] = Float4ToFp8(DatumGetFloat4(elemsp[i])); + } + else if (ARR_ELEMTYPE(array) == NUMERICOID) + { + for (int i = 0; i < nelemsp; i++) + result->x[i] = Float4ToFp8(DatumGetFloat4(DirectFunctionCall1(numeric_float4, elemsp[i]))); + } + else + { + ereport(ERROR, + (errcode(ERRCODE_DATA_EXCEPTION), + errmsg("unsupported array type"))); + } + + /* + * Free allocation from deconstruct_array. Do not free individual elements + * when pass-by-reference since they point to original array. 
+ */ + pfree(elemsp); + + /* Check elements */ + for (int i = 0; i < result->dim; i++) + CheckElement(result->x[i]); + + PG_RETURN_POINTER(result); +} + +/* + * Convert fp8 vector to float4[] + */ +FUNCTION_PREFIX PG_FUNCTION_INFO_V1(minivec_to_float4); +Datum +minivec_to_float4(PG_FUNCTION_ARGS) +{ + MiniVector *vec = PG_GETARG_MINIVEC_P(0); + Datum *datums; + ArrayType *result; + + datums = (Datum *) palloc(sizeof(Datum) * vec->dim); + + for (int i = 0; i < vec->dim; i++) + datums[i] = Float4GetDatum(Fp8ToFloat4(vec->x[i])); + + /* Use TYPALIGN_INT for float4 */ + result = construct_array(datums, vec->dim, FLOAT4OID, sizeof(float4), true, TYPALIGN_INT); + + pfree(datums); + + PG_RETURN_POINTER(result); +} + +/* + * Convert vector to fp8 vector + */ +FUNCTION_PREFIX PG_FUNCTION_INFO_V1(vector_to_minivec); +Datum +vector_to_minivec(PG_FUNCTION_ARGS) +{ + Vector *vec = PG_GETARG_VECTOR_P(0); + int32 typmod = PG_GETARG_INT32(1); + MiniVector *result; + + CheckDim(vec->dim); + CheckExpectedDim(typmod, vec->dim); + + result = InitMiniVector(vec->dim); + + for (int i = 0; i < vec->dim; i++) + result->x[i] = Float4ToFp8(vec->x[i]); + + PG_RETURN_POINTER(result); +} + +static float +MinivecL2SquaredDistance(int dim, fp8 * ax, fp8 * bx) +{ + float distance = 0.0; + + for (int i = 0; i < dim; i++) + { + float diff = Fp8ToFloat4(ax[i]) - Fp8ToFloat4(bx[i]); + + distance += diff * diff; + } + + return distance; +} + +/* + * Get the L2 distance between fp8 vectors + */ +FUNCTION_PREFIX PG_FUNCTION_INFO_V1(minivec_l2_distance); +Datum +minivec_l2_distance(PG_FUNCTION_ARGS) +{ + MiniVector *a = PG_GETARG_MINIVEC_P(0); + MiniVector *b = PG_GETARG_MINIVEC_P(1); + + CheckDims(a, b); + + PG_RETURN_FLOAT8(sqrt((double) MinivecL2SquaredDistance(a->dim, a->x, b->x))); +} + +/* + * Get the L2 squared distance between fp8 vectors + */ +FUNCTION_PREFIX PG_FUNCTION_INFO_V1(minivec_l2_squared_distance); +Datum +minivec_l2_squared_distance(PG_FUNCTION_ARGS) +{ + MiniVector *a = 
PG_GETARG_MINIVEC_P(0); + MiniVector *b = PG_GETARG_MINIVEC_P(1); + + CheckDims(a, b); + + PG_RETURN_FLOAT8((double) MinivecL2SquaredDistance(a->dim, a->x, b->x)); +} + +static float +MinivecInnerProduct(int dim, fp8 * ax, fp8 * bx) +{ + float distance = 0.0; + + for (int i = 0; i < dim; i++) + distance += Fp8ToFloat4(ax[i]) * Fp8ToFloat4(bx[i]); + + return distance; +} + +/* + * Get the inner product of two fp8 vectors + */ +FUNCTION_PREFIX PG_FUNCTION_INFO_V1(minivec_inner_product); +Datum +minivec_inner_product(PG_FUNCTION_ARGS) +{ + MiniVector *a = PG_GETARG_MINIVEC_P(0); + MiniVector *b = PG_GETARG_MINIVEC_P(1); + + CheckDims(a, b); + + PG_RETURN_FLOAT8((double) MinivecInnerProduct(a->dim, a->x, b->x)); +} + +/* + * Get the negative inner product of two fp8 vectors + */ +FUNCTION_PREFIX PG_FUNCTION_INFO_V1(minivec_negative_inner_product); +Datum +minivec_negative_inner_product(PG_FUNCTION_ARGS) +{ + MiniVector *a = PG_GETARG_MINIVEC_P(0); + MiniVector *b = PG_GETARG_MINIVEC_P(1); + + CheckDims(a, b); + + PG_RETURN_FLOAT8((double) -MinivecInnerProduct(a->dim, a->x, b->x)); +} + +static double +MinivecCosineSimilarity(int dim, fp8 * ax, fp8 * bx) +{ + float similarity = 0.0; + float norma = 0.0; + float normb = 0.0; + + for (int i = 0; i < dim; i++) + { + float axi = Fp8ToFloat4(ax[i]); + float bxi = Fp8ToFloat4(bx[i]); + + similarity += axi * bxi; + norma += axi * axi; + normb += bxi * bxi; + } + + /* Use sqrt(a * b) over sqrt(a) * sqrt(b) */ + return (double) similarity / sqrt((double) norma * (double) normb); +} + +/* + * Get the cosine distance between two fp8 vectors + */ +FUNCTION_PREFIX PG_FUNCTION_INFO_V1(minivec_cosine_distance); +Datum +minivec_cosine_distance(PG_FUNCTION_ARGS) +{ + MiniVector *a = PG_GETARG_MINIVEC_P(0); + MiniVector *b = PG_GETARG_MINIVEC_P(1); + double similarity; + + CheckDims(a, b); + + similarity = MinivecCosineSimilarity(a->dim, a->x, b->x); + +#ifdef _MSC_VER + /* /fp:fast may not propagate NaN */ + if (isnan(similarity)) + 
PG_RETURN_FLOAT8(NAN); +#endif + + /* Keep in range */ + if (similarity > 1) + similarity = 1; + else if (similarity < -1) + similarity = -1; + + PG_RETURN_FLOAT8(1 - similarity); +} + +/* + * Get the distance for spherical k-means + * Currently uses angular distance since needs to satisfy triangle inequality + * Assumes inputs are unit vectors (skips norm) + */ +FUNCTION_PREFIX PG_FUNCTION_INFO_V1(minivec_spherical_distance); +Datum +minivec_spherical_distance(PG_FUNCTION_ARGS) +{ + MiniVector *a = PG_GETARG_MINIVEC_P(0); + MiniVector *b = PG_GETARG_MINIVEC_P(1); + double distance; + + CheckDims(a, b); + + distance = (double) MinivecInnerProduct(a->dim, a->x, b->x); + + /* Prevent NaN with acos with loss of precision */ + if (distance > 1) + distance = 1; + else if (distance < -1) + distance = -1; + + PG_RETURN_FLOAT8(acos(distance) / M_PI); +} + +static float +MinivecL1Distance(int dim, fp8 * ax, fp8 * bx) +{ + float distance = 0.0; + + /* Auto-vectorized */ + for (int i = 0; i < dim; i++) + distance += fabsf(Fp8ToFloat4(ax[i]) - Fp8ToFloat4(bx[i])); + + return distance; +} + +/* + * Get the L1 distance between two fp8 vectors + */ +FUNCTION_PREFIX PG_FUNCTION_INFO_V1(minivec_l1_distance); +Datum +minivec_l1_distance(PG_FUNCTION_ARGS) +{ + MiniVector *a = PG_GETARG_MINIVEC_P(0); + MiniVector *b = PG_GETARG_MINIVEC_P(1); + + CheckDims(a, b); + + PG_RETURN_FLOAT8((double) MinivecL1Distance(a->dim, a->x, b->x)); +} + +/* + * Get the dimensions of a fp8 vector + */ +FUNCTION_PREFIX PG_FUNCTION_INFO_V1(minivec_vector_dims); +Datum +minivec_vector_dims(PG_FUNCTION_ARGS) +{ + MiniVector *a = PG_GETARG_MINIVEC_P(0); + + PG_RETURN_INT32(a->dim); +} + +/* + * Get the L2 norm of a fp8 vector + */ +FUNCTION_PREFIX PG_FUNCTION_INFO_V1(minivec_l2_norm); +Datum +minivec_l2_norm(PG_FUNCTION_ARGS) +{ + MiniVector *a = PG_GETARG_MINIVEC_P(0); + fp8 *ax = a->x; + double norm = 0.0; + + /* Auto-vectorized */ + for (int i = 0; i < a->dim; i++) + { + double axi = (double) 
Fp8ToFloat4(ax[i]); + + norm += axi * axi; + } + + PG_RETURN_FLOAT8(sqrt(norm)); +} + +/* + * Normalize a fp8 vector with the L2 norm + */ +FUNCTION_PREFIX PG_FUNCTION_INFO_V1(minivec_l2_normalize); +Datum +minivec_l2_normalize(PG_FUNCTION_ARGS) +{ + MiniVector *a = PG_GETARG_MINIVEC_P(0); + fp8 *ax = a->x; + double norm = 0; + MiniVector *result; + fp8 *rx; + + result = InitMiniVector(a->dim); + rx = result->x; + + /* Auto-vectorized */ + for (int i = 0; i < a->dim; i++) + norm += (double) Fp8ToFloat4(ax[i]) * (double) Fp8ToFloat4(ax[i]); + + norm = sqrt(norm); + + /* Return zero vector for zero norm */ + if (norm > 0) + { + for (int i = 0; i < a->dim; i++) + rx[i] = Float4ToFp8Unchecked(Fp8ToFloat4(ax[i]) / norm); + + /* Check for overflow */ + for (int i = 0; i < a->dim; i++) + { + if (Fp8IsNan(rx[i])) + float_overflow_error(); + } + } + + PG_RETURN_POINTER(result); +} + +/* + * Add fp8 vectors + */ +FUNCTION_PREFIX PG_FUNCTION_INFO_V1(minivec_add); +Datum +minivec_add(PG_FUNCTION_ARGS) +{ + MiniVector *a = PG_GETARG_MINIVEC_P(0); + MiniVector *b = PG_GETARG_MINIVEC_P(1); + fp8 *ax = a->x; + fp8 *bx = b->x; + MiniVector *result; + fp8 *rx; + + CheckDims(a, b); + + result = InitMiniVector(a->dim); + rx = result->x; + + /* Auto-vectorized */ + for (int i = 0, imax = a->dim; i < imax; i++) + { + /* + * fp8 holds encoded bits (it is sent and received as a raw byte), so + * always add in float4; a native + would combine the bit patterns + */ + rx[i] = Float4ToFp8Unchecked(Fp8ToFloat4(ax[i]) + Fp8ToFloat4(bx[i])); + } + + /* Check for overflow */ + for (int i = 0, imax = a->dim; i < imax; i++) + { + if (Fp8IsNan(rx[i])) + float_overflow_error(); + } + + PG_RETURN_POINTER(result); +} + +/* + * Subtract fp8 vectors + */ +FUNCTION_PREFIX PG_FUNCTION_INFO_V1(minivec_sub); +Datum +minivec_sub(PG_FUNCTION_ARGS) +{ + MiniVector *a = PG_GETARG_MINIVEC_P(0); + MiniVector *b = PG_GETARG_MINIVEC_P(1); + fp8 *ax = a->x; + fp8 *bx = b->x; + MiniVector *result; + fp8 *rx; + + CheckDims(a, b); + + result = InitMiniVector(a->dim); + rx = result->x; + + /* 
Auto-vectorized */ + for (int i = 0, imax = a->dim; i < imax; i++) + { +#ifdef FLT16_SUPPORT + rx[i] = ax[i] - bx[i]; +#else + rx[i] = Float4ToFp8Unchecked(Fp8ToFloat4(ax[i]) - Fp8ToFloat4(bx[i])); +#endif + } + + /* Check for overflow */ + for (int i = 0, imax = a->dim; i < imax; i++) + { + if (Fp8IsNan(rx[i])) + float_overflow_error(); + } + + PG_RETURN_POINTER(result); +} + +/* + * Multiply fp8 vectors + */ +FUNCTION_PREFIX PG_FUNCTION_INFO_V1(minivec_mul); +Datum +minivec_mul(PG_FUNCTION_ARGS) +{ + MiniVector *a = PG_GETARG_MINIVEC_P(0); + MiniVector *b = PG_GETARG_MINIVEC_P(1); + fp8 *ax = a->x; + fp8 *bx = b->x; + MiniVector *result; + fp8 *rx; + + CheckDims(a, b); + + result = InitMiniVector(a->dim); + rx = result->x; + + /* Auto-vectorized */ + for (int i = 0, imax = a->dim; i < imax; i++) + { +#ifdef FLT16_SUPPORT + rx[i] = ax[i] * bx[i]; +#else + rx[i] = Float4ToFp8Unchecked(Fp8ToFloat4(ax[i]) * Fp8ToFloat4(bx[i])); +#endif + } + + /* Check for overflow and underflow */ + for (int i = 0, imax = a->dim; i < imax; i++) + { + if (Fp8IsNan(rx[i])) + float_overflow_error(); + + if (Fp8IsZero(rx[i]) && !(Fp8IsZero(ax[i]) || Fp8IsZero(bx[i]))) + float_underflow_error(); + } + + PG_RETURN_POINTER(result); +} + +/* + * Concatenate fp8 vectors + */ +FUNCTION_PREFIX PG_FUNCTION_INFO_V1(minivec_concat); +Datum +minivec_concat(PG_FUNCTION_ARGS) +{ + MiniVector *a = PG_GETARG_MINIVEC_P(0); + MiniVector *b = PG_GETARG_MINIVEC_P(1); + MiniVector *result; + int dim = a->dim + b->dim; + + CheckDim(dim); + result = InitMiniVector(dim); + + for (int i = 0; i < a->dim; i++) + result->x[i] = a->x[i]; + + for (int i = 0; i < b->dim; i++) + result->x[i + a->dim] = b->x[i]; + + PG_RETURN_POINTER(result); +} + +/* + * Quantize a fp8 vector + */ +FUNCTION_PREFIX PG_FUNCTION_INFO_V1(minivec_binary_quantize); +Datum +minivec_binary_quantize(PG_FUNCTION_ARGS) +{ + MiniVector *a = PG_GETARG_MINIVEC_P(0); + fp8 *ax = a->x; + VarBit *result = InitBitVector(a->dim); + unsigned char *rx = 
VARBITS(result);
+
+	/* Bit i is set when element i is > 0; bits are packed MSB-first per byte */
+	for (int i = 0; i < a->dim; i++)
+		rx[i / 8] |= (Fp8ToFloat4(ax[i]) > 0) << (7 - (i % 8));
+
+	PG_RETURN_VARBIT_P(result);
+}
+
+/*
+ * Get a subvector
+ *
+ * Arguments are the vector, a 1-based start index, and an element count.
+ * The requested range is clipped to the vector: a start below 1 is moved to
+ * 1 and an end past the last element is moved to a->dim + 1. An error is
+ * raised when count < 1, when start is past the last element, or (through
+ * CheckDim) when the clipped range is not a valid dimension.
+ */
+FUNCTION_PREFIX PG_FUNCTION_INFO_V1(minivec_subvector);
+Datum
+minivec_subvector(PG_FUNCTION_ARGS)
+{
+	MiniVector *a = PG_GETARG_MINIVEC_P(0);
+	int32		start = PG_GETARG_INT32(1);
+	int32		count = PG_GETARG_INT32(2);
+	int32		end;
+	fp8		   *ax = a->x;
+	MiniVector *result;
+	int32		dim;
+
+	if (count < 1)
+		ereport(ERROR,
+				(errcode(ERRCODE_DATA_EXCEPTION),
+				 errmsg("minivec must have at least 1 dimension")));
+
+	/*
+	 * Check if (start + count > a->dim), avoiding integer overflow. a->dim
+	 * and count are both positive, so a->dim - count won't overflow.
+	 */
+	if (start > a->dim - count)
+		end = a->dim + 1;
+	else
+		end = start + count;
+
+	/* Indexing starts at 1, like substring */
+	if (start < 1)
+		start = 1;
+	else if (start > a->dim)
+		ereport(ERROR,
+				(errcode(ERRCODE_DATA_EXCEPTION),
+				 errmsg("minivec must have at least 1 dimension")));
+
+	/* end is exclusive, so this is the number of elements to copy */
+	dim = end - start;
+	CheckDim(dim);
+	result = InitMiniVector(dim);
+
+	for (int i = 0; i < dim; i++)
+		result->x[i] = ax[start - 1 + i];
+
+	PG_RETURN_POINTER(result);
+}
+
+/*
+ * Internal helper to compare fp8 vectors
+ *
+ * Returns -1, 0, or 1. Elements are compared as decoded float4 values over
+ * the common prefix; if the prefix is equal, the vector with fewer
+ * dimensions sorts first.
+ */
+static int
+minivec_cmp_internal(MiniVector * a, MiniVector * b)
+{
+	int			dim = Min(a->dim, b->dim);
+
+	/* Check values before dimensions to be consistent with Postgres arrays */
+	for (int i = 0; i < dim; i++)
+	{
+		if (Fp8ToFloat4(a->x[i]) < Fp8ToFloat4(b->x[i]))
+			return -1;
+
+		if (Fp8ToFloat4(a->x[i]) > Fp8ToFloat4(b->x[i]))
+			return 1;
+	}
+
+	if (a->dim < b->dim)
+		return -1;
+
+	if (a->dim > b->dim)
+		return 1;
+
+	return 0;
+}
+
+/*
+ * Less than
+ */
+FUNCTION_PREFIX PG_FUNCTION_INFO_V1(minivec_lt);
+Datum
+minivec_lt(PG_FUNCTION_ARGS)
+{
+	MiniVector *a = PG_GETARG_MINIVEC_P(0);
+	MiniVector *b = PG_GETARG_MINIVEC_P(1);
+
+	PG_RETURN_BOOL(minivec_cmp_internal(a, b) < 0);
+}
+
+/*
+ * Less than or equal
+ */
+FUNCTION_PREFIX PG_FUNCTION_INFO_V1(minivec_le);
+Datum
+minivec_le(PG_FUNCTION_ARGS)
+{
+	MiniVector *a = PG_GETARG_MINIVEC_P(0);
+	MiniVector *b = PG_GETARG_MINIVEC_P(1);
+
+	PG_RETURN_BOOL(minivec_cmp_internal(a, b) <= 0);
+}
+
+/*
+ * Equal
+ */
+FUNCTION_PREFIX PG_FUNCTION_INFO_V1(minivec_eq);
+Datum
+minivec_eq(PG_FUNCTION_ARGS)
+{
+	MiniVector *a = PG_GETARG_MINIVEC_P(0);
+	MiniVector *b = PG_GETARG_MINIVEC_P(1);
+
+	PG_RETURN_BOOL(minivec_cmp_internal(a, b) == 0);
+}
+
+/*
+ * Not equal
+ */
+FUNCTION_PREFIX PG_FUNCTION_INFO_V1(minivec_ne);
+Datum
+minivec_ne(PG_FUNCTION_ARGS)
+{
+	MiniVector *a = PG_GETARG_MINIVEC_P(0);
+	MiniVector *b = PG_GETARG_MINIVEC_P(1);
+
+	PG_RETURN_BOOL(minivec_cmp_internal(a, b) != 0);
+}
+
+/*
+ * Greater than or equal
+ */
+FUNCTION_PREFIX PG_FUNCTION_INFO_V1(minivec_ge);
+Datum
+minivec_ge(PG_FUNCTION_ARGS)
+{
+	MiniVector *a = PG_GETARG_MINIVEC_P(0);
+	MiniVector *b = PG_GETARG_MINIVEC_P(1);
+
+	PG_RETURN_BOOL(minivec_cmp_internal(a, b) >= 0);
+}
+
+/*
+ * Greater than
+ */
+FUNCTION_PREFIX PG_FUNCTION_INFO_V1(minivec_gt);
+Datum
+minivec_gt(PG_FUNCTION_ARGS)
+{
+	MiniVector *a = PG_GETARG_MINIVEC_P(0);
+	MiniVector *b = PG_GETARG_MINIVEC_P(1);
+
+	PG_RETURN_BOOL(minivec_cmp_internal(a, b) > 0);
+}
+
+/*
+ * Compare fp8 vectors
+ *
+ * Returns the -1 / 0 / 1 result of minivec_cmp_internal as an int32.
+ */
+FUNCTION_PREFIX PG_FUNCTION_INFO_V1(minivec_cmp);
+Datum
+minivec_cmp(PG_FUNCTION_ARGS)
+{
+	MiniVector *a = PG_GETARG_MINIVEC_P(0);
+	MiniVector *b = PG_GETARG_MINIVEC_P(1);
+
+	PG_RETURN_INT32(minivec_cmp_internal(a, b));
+}
diff --git a/src/minivec.h b/src/minivec.h
new file mode 100644
index 0000000..ddca3fc
--- /dev/null
+++ b/src/minivec.h
@@ -0,0 +1,156 @@
+#ifndef MINIVEC_H
+#define MINIVEC_H
+
+/* NOTE(review): the header name after #include is missing in this copy of the patch - restore it from the original */
+#include
+
+#define MINIVEC_MAX_DIM 16000
+
+/* An fp8 value is stored as its raw 8-bit encoding */
+#define fp8 uint8
+
+#define MINIVEC_SIZE(_dim)		(offsetof(MiniVector, x) + sizeof(fp8)*(_dim))
+#define DatumGetMiniVector(x)	((MiniVector *) PG_DETOAST_DATUM(x))
+#define PG_GETARG_MINIVEC_P(x)	DatumGetMiniVector(PG_GETARG_DATUM(x))
+#define
PG_RETURN_MINIVEC_P(x) PG_RETURN_POINTER(x) + +typedef struct MiniVector +{ + int32 vl_len_; /* varlena header (do not touch directly!) */ + int16 dim; /* number of dimensions */ + int16 unused; /* reserved for future use, always zero */ + fp8 x[FLEXIBLE_ARRAY_MEMBER]; +} MiniVector; + +MiniVector *InitMiniVector(int dim); + +/* + * Check if fp8 is NaN + */ +static inline bool +Fp8IsNan(fp8 num) +{ + return (num & 0x7F) == 0x7F; +} + +/* + * Check if fp8 is zero + */ +static inline bool +Fp8IsZero(fp8 num) +{ + return num == 0; +} + +/* + * Convert a fp8 to a float4 + */ +static inline float +Fp8ToFloat4(fp8 num) +{ + float lookup[128] = {0, 0.00195312, 0.00390625, 0.00585938, 0.0078125, 0.00976562, 0.0117188, 0.0136719, 0.015625, 0.0175781, 0.0195312, 0.0214844, 0.0234375, 0.0253906, 0.0273438, 0.0292969, 0.03125, 0.0351562, 0.0390625, 0.0429688, 0.046875, 0.0507812, 0.0546875, 0.0585938, 0.0625, 0.0703125, 0.078125, 0.0859375, 0.09375, 0.101562, 0.109375, 0.117188, 0.125, 0.140625, 0.15625, 0.171875, 0.1875, 0.203125, 0.21875, 0.234375, 0.25, 0.28125, 0.3125, 0.34375, 0.375, 0.40625, 0.4375, 0.46875, 0.5, 0.5625, 0.625, 0.6875, 0.75, 0.8125, 0.875, 0.9375, 1, 1.125, 1.25, 1.375, 1.5, 1.625, 1.75, 1.875, 2, 2.25, 2.5, 2.75, 3, 3.25, 3.5, 3.75, 4, 4.5, 5, 5.5, 6, 6.5, 7, 7.5, 8, 9, 10, 11, 12, 13, 14, 15, 16, 18, 20, 22, 24, 26, 28, 30, 32, 36, 40, 44, 48, 52, 56, 60, 64, 72, 80, 88, 96, 104, 112, 120, 128, 144, 160, 176, 192, 208, 224, 240, 256, 288, 320, 352, 384, 416, 448, NAN}; + float v = lookup[num & 0x7F]; + + return (num & 0x80) == 0x80 ? 
-v : v; +} + +/* + * Convert a float4 to a fp8 + */ +static inline fp8 +Float4ToFp8Unchecked(float num) +{ + union + { + float f; + uint32 i; + } swapfloat; + + uint32 bin; + int exponent; + int mantissa; + uint8 result; + + swapfloat.f = num; + bin = swapfloat.i; + exponent = (bin & 0x7F800000) >> 23; + mantissa = bin & 0x007FFFFF; + + /* Sign */ + result = (bin & 0x80000000) >> 24; + + if (isinf(num) || isnan(num)) + { + /* NaN */ + result |= 0x7F; + } + else if (exponent > 116) + { + int m; + int gr; + int s; + + exponent -= 127; + s = mantissa & 0x000FFFFF; + + /* Subnormal */ + if (exponent < -6) + { + int diff = -exponent - 6; + + mantissa >>= diff; + mantissa += 1 << (23 - diff); + s |= mantissa & 0x000FFFFF; + } + + m = mantissa >> 20; + + /* Round */ + gr = (mantissa >> 19) % 4; + if (gr == 3 || (gr == 1 && s != 0)) + m += 1; + + if (m == 8) + { + m = 0; + exponent += 1; + } + + if (exponent > 8) + { + /* Infinite, which is NaN */ + result |= 0x7F; + } + else + { + if (exponent >= -7) + result |= (exponent + 7) << 3; + + result |= m; + } + } + + return result; +} + +/* + * Convert a float4 to a fp8 + */ +static inline fp8 +Float4ToFp8(float num) +{ + fp8 result = Float4ToFp8Unchecked(num); + + if (unlikely(Fp8IsNan(result)) && !isnan(num)) + { + char *buf = palloc(FLOAT_SHORTEST_DECIMAL_LEN); + + float_to_shortest_decimal_buf(num, buf); + + ereport(ERROR, + (errcode(ERRCODE_NUMERIC_VALUE_OUT_OF_RANGE), + errmsg("\"%s\" is out of range for type minivec", buf))); + } + + return result; +} + +#endif diff --git a/src/vector.c b/src/vector.c index a5b2aac..2c421c9 100644 --- a/src/vector.c +++ b/src/vector.c @@ -13,6 +13,7 @@ #include "ivfflat.h" #include "lib/stringinfo.h" #include "libpq/pqformat.h" +#include "minivec.h" #include "port.h" /* for strtof() */ #include "sparsevec.h" #include "utils/array.h" @@ -542,6 +543,28 @@ halfvec_to_vector(PG_FUNCTION_ARGS) PG_RETURN_POINTER(result); } +/* + * Convert fp8 vector to vector + */ +FUNCTION_PREFIX 
PG_FUNCTION_INFO_V1(minivec_to_vector); +Datum +minivec_to_vector(PG_FUNCTION_ARGS) +{ + MiniVector *vec = PG_GETARG_MINIVEC_P(0); + int32 typmod = PG_GETARG_INT32(1); + Vector *result; + + CheckDim(vec->dim); + CheckExpectedDim(typmod, vec->dim); + + result = InitVector(vec->dim); + + for (int i = 0; i < vec->dim; i++) + result->x[i] = Fp8ToFloat4(vec->x[i]); + + PG_RETURN_POINTER(result); +} + VECTOR_TARGET_CLONES static float VectorL2SquaredDistance(int dim, float *ax, float *bx) { diff --git a/test/expected/btree.out b/test/expected/btree.out index 999a160..e90008d 100644 --- a/test/expected/btree.out +++ b/test/expected/btree.out @@ -38,6 +38,26 @@ SELECT * FROM t ORDER BY val; (4 rows) +DROP TABLE t; +-- minivec +CREATE TABLE t (val minivec(3)); +INSERT INTO t (val) VALUES ('[0,0,0]'), ('[1,2,3]'), ('[1,1,1]'), (NULL); +CREATE INDEX ON t (val); +SELECT * FROM t WHERE val = '[1,2,3]'; + val +--------- + [1,2,3] +(1 row) + +SELECT * FROM t ORDER BY val; + val +--------- + [0,0,0] + [1,1,1] + [1,2,3] + +(4 rows) + DROP TABLE t; -- sparsevec CREATE TABLE t (val sparsevec(3)); diff --git a/test/expected/cast.out b/test/expected/cast.out index c180fe6..34a57c3 100644 --- a/test/expected/cast.out +++ b/test/expected/cast.out @@ -140,6 +140,64 @@ SELECT '{1e-8,-1e-8}'::real[]::halfvec; [0,-0] (1 row) +SELECT '[1,2,3]'::vector::minivec; + minivec +--------- + [1,2,3] +(1 row) + +SELECT '[1,2,3]'::vector::minivec(3); + minivec +--------- + [1,2,3] +(1 row) + +SELECT '[1,2,3]'::vector::minivec(2); +ERROR: expected 2 dimensions, not 3 +SELECT '[465]'::vector::minivec; +ERROR: "465" is out of range for type minivec +SELECT '[1e-8]'::vector::minivec; + minivec +--------- + [0] +(1 row) + +SELECT '[1,2,3]'::minivec::vector; + vector +--------- + [1,2,3] +(1 row) + +SELECT '[1,2,3]'::minivec::vector(3); + vector +--------- + [1,2,3] +(1 row) + +SELECT '[1,2,3]'::minivec::vector(2); +ERROR: expected 2 dimensions, not 3 +SELECT '{1,2,3}'::real[]::minivec; + minivec +--------- 
+ [1,2,3] +(1 row) + +SELECT '{1,2,3}'::real[]::minivec(3); + minivec +--------- + [1,2,3] +(1 row) + +SELECT '{1,2,3}'::real[]::minivec(2); +ERROR: expected 2 dimensions, not 3 +SELECT '{465,-465}'::real[]::minivec; +ERROR: "465" is out of range for type minivec +SELECT '{1e-8,-1e-8}'::real[]::minivec; + minivec +--------- + [0,-0] +(1 row) + SELECT '[0,1.5,0,3.5,0]'::vector::sparsevec; sparsevec ----------------- diff --git a/test/expected/copy.out b/test/expected/copy.out index 9b4ebc0..59393a8 100644 --- a/test/expected/copy.out +++ b/test/expected/copy.out @@ -30,6 +30,23 @@ SELECT * FROM t2 ORDER BY val; (4 rows) +DROP TABLE t; +DROP TABLE t2; +-- minivec +CREATE TABLE t (val minivec(3)); +INSERT INTO t (val) VALUES ('[0,0,0]'), ('[1,2,3]'), ('[1,1,1]'), (NULL); +CREATE TABLE t2 (val minivec(3)); +\copy t TO 'results/minivec.bin' WITH (FORMAT binary) +\copy t2 FROM 'results/minivec.bin' WITH (FORMAT binary) +SELECT * FROM t2 ORDER BY val; + val +--------- + [0,0,0] + [1,1,1] + [1,2,3] + +(4 rows) + DROP TABLE t; DROP TABLE t2; -- sparsevec diff --git a/test/expected/minivec.out b/test/expected/minivec.out new file mode 100644 index 0000000..ab5224c --- /dev/null +++ b/test/expected/minivec.out @@ -0,0 +1,588 @@ +SELECT '[1,2,3]'::minivec; + minivec +--------- + [1,2,3] +(1 row) + +SELECT '[-1,-2,-3]'::minivec; + minivec +------------ + [-1,-2,-3] +(1 row) + +SELECT '[1.,2.,3.]'::minivec; + minivec +--------- + [1,2,3] +(1 row) + +SELECT ' [ 1, 2 , 3 ] '::minivec; + minivec +--------- + [1,2,3] +(1 row) + +SELECT '[1.23456]'::minivec; + minivec +--------- + [1.25] +(1 row) + +SELECT '[hello,1]'::minivec; +ERROR: invalid input syntax for type minivec: "[hello,1]" +LINE 1: SELECT '[hello,1]'::minivec; + ^ +SELECT '[NaN,1]'::minivec; +ERROR: NaN not allowed in minivec +LINE 1: SELECT '[NaN,1]'::minivec; + ^ +SELECT '[Infinity,1]'::minivec; +ERROR: "Infinity" is out of range for type minivec +LINE 1: SELECT '[Infinity,1]'::minivec; + ^ +SELECT 
'[-Infinity,1]'::minivec; +ERROR: "-Infinity" is out of range for type minivec +LINE 1: SELECT '[-Infinity,1]'::minivec; + ^ +SELECT '[65519,-65519]'::minivec; +ERROR: "65519" is out of range for type minivec +LINE 1: SELECT '[65519,-65519]'::minivec; + ^ +SELECT '[65520,-65520]'::minivec; +ERROR: "65520" is out of range for type minivec +LINE 1: SELECT '[65520,-65520]'::minivec; + ^ +SELECT '[1e-8,-1e-8]'::minivec; + minivec +--------- + [0,-0] +(1 row) + +SELECT '[4e38,1]'::minivec; +ERROR: "4e38" is out of range for type minivec +LINE 1: SELECT '[4e38,1]'::minivec; + ^ +SELECT '[1e-46,1]'::minivec; + minivec +--------- + [0,1] +(1 row) + +SELECT '[1,2,3'::minivec; +ERROR: invalid input syntax for type minivec: "[1,2,3" +LINE 1: SELECT '[1,2,3'::minivec; + ^ +SELECT '[1,2,3]9'::minivec; +ERROR: invalid input syntax for type minivec: "[1,2,3]9" +LINE 1: SELECT '[1,2,3]9'::minivec; + ^ +DETAIL: Junk after closing right brace. +SELECT '1,2,3'::minivec; +ERROR: invalid input syntax for type minivec: "1,2,3" +LINE 1: SELECT '1,2,3'::minivec; + ^ +DETAIL: Vector contents must start with "[". +SELECT ''::minivec; +ERROR: invalid input syntax for type minivec: "" +LINE 1: SELECT ''::minivec; + ^ +DETAIL: Vector contents must start with "[". 
+SELECT '['::minivec; +ERROR: invalid input syntax for type minivec: "[" +LINE 1: SELECT '['::minivec; + ^ +SELECT '[ '::minivec; +ERROR: invalid input syntax for type minivec: "[ " +LINE 1: SELECT '[ '::minivec; + ^ +SELECT '[,'::minivec; +ERROR: invalid input syntax for type minivec: "[," +LINE 1: SELECT '[,'::minivec; + ^ +SELECT '[]'::minivec; +ERROR: minivec must have at least 1 dimension +LINE 1: SELECT '[]'::minivec; + ^ +SELECT '[ ]'::minivec; +ERROR: minivec must have at least 1 dimension +LINE 1: SELECT '[ ]'::minivec; + ^ +SELECT '[,]'::minivec; +ERROR: invalid input syntax for type minivec: "[,]" +LINE 1: SELECT '[,]'::minivec; + ^ +SELECT '[1,]'::minivec; +ERROR: invalid input syntax for type minivec: "[1,]" +LINE 1: SELECT '[1,]'::minivec; + ^ +SELECT '[1a]'::minivec; +ERROR: invalid input syntax for type minivec: "[1a]" +LINE 1: SELECT '[1a]'::minivec; + ^ +SELECT '[1,,3]'::minivec; +ERROR: invalid input syntax for type minivec: "[1,,3]" +LINE 1: SELECT '[1,,3]'::minivec; + ^ +SELECT '[1, ,3]'::minivec; +ERROR: invalid input syntax for type minivec: "[1, ,3]" +LINE 1: SELECT '[1, ,3]'::minivec; + ^ +SELECT '[1,2,3]'::minivec(3); + minivec +--------- + [1,2,3] +(1 row) + +SELECT '[1,2,3]'::minivec(2); +ERROR: expected 2 dimensions, not 3 +SELECT '[1,2,3]'::minivec(3, 2); +ERROR: invalid type modifier +LINE 1: SELECT '[1,2,3]'::minivec(3, 2); + ^ +SELECT '[1,2,3]'::minivec('a'); +ERROR: invalid input syntax for type integer: "a" +LINE 1: SELECT '[1,2,3]'::minivec('a'); + ^ +SELECT '[1,2,3]'::minivec(0); +ERROR: dimensions for type minivec must be at least 1 +LINE 1: SELECT '[1,2,3]'::minivec(0); + ^ +SELECT '[1,2,3]'::minivec(16001); +ERROR: dimensions for type minivec cannot exceed 16000 +LINE 1: SELECT '[1,2,3]'::minivec(16001); + ^ +SELECT unnest('{"[1,2,3]", "[4,5,6]"}'::minivec[]); + unnest +--------- + [1,2,3] + [4,5,6] +(2 rows) + +SELECT '{"[1,2,3]"}'::minivec(2)[]; +ERROR: expected 2 dimensions, not 3 +SELECT '[1,2,3]'::minivec + '[4,5,6]'; + 
?column? +---------- + [5,7,9] +(1 row) + +SELECT '[448]'::minivec + '[448]'; +ERROR: value out of range: overflow +SELECT '[1,2]'::minivec + '[3]'; +ERROR: different minivec dimensions 2 and 1 +SELECT '[1,2,3]'::minivec - '[4,5,6]'; + ?column? +------------ + [-3,-3,-3] +(1 row) + +SELECT '[-448]'::minivec - '[448]'; +ERROR: value out of range: overflow +SELECT '[1,2]'::minivec - '[3]'; +ERROR: different minivec dimensions 2 and 1 +SELECT '[1,2,3]'::minivec * '[4,5,6]'; + ?column? +----------- + [4,10,18] +(1 row) + +SELECT '[448]'::minivec * '[448]'; +ERROR: value out of range: overflow +SELECT '[1e-7]'::minivec * '[1e-7]'; + ?column? +---------- + [0] +(1 row) + +SELECT '[1,2]'::minivec * '[3]'; +ERROR: different minivec dimensions 2 and 1 +SELECT '[1,2,3]'::minivec || '[4,5]'; + ?column? +------------- + [1,2,3,4,5] +(1 row) + +SELECT array_fill(0, ARRAY[16000])::minivec || '[1]'; +ERROR: minivec cannot have more than 16000 dimensions +SELECT '[1,2,3]'::minivec < '[1,2,3]'; + ?column? +---------- + f +(1 row) + +SELECT '[1,2,3]'::minivec < '[1,2]'; + ?column? +---------- + f +(1 row) + +SELECT '[1,2,3]'::minivec <= '[1,2,3]'; + ?column? +---------- + t +(1 row) + +SELECT '[1,2,3]'::minivec <= '[1,2]'; + ?column? +---------- + f +(1 row) + +SELECT '[1,2,3]'::minivec = '[1,2,3]'; + ?column? +---------- + t +(1 row) + +SELECT '[1,2,3]'::minivec = '[1,2]'; + ?column? +---------- + f +(1 row) + +SELECT '[1,2,3]'::minivec != '[1,2,3]'; + ?column? +---------- + f +(1 row) + +SELECT '[1,2,3]'::minivec != '[1,2]'; + ?column? +---------- + t +(1 row) + +SELECT '[1,2,3]'::minivec >= '[1,2,3]'; + ?column? +---------- + t +(1 row) + +SELECT '[1,2,3]'::minivec >= '[1,2]'; + ?column? +---------- + t +(1 row) + +SELECT '[1,2,3]'::minivec > '[1,2,3]'; + ?column? +---------- + f +(1 row) + +SELECT '[1,2,3]'::minivec > '[1,2]'; + ?column? 
+---------- + t +(1 row) + +SELECT minivec_cmp('[1,2,3]', '[1,2,3]'); + minivec_cmp +------------- + 0 +(1 row) + +SELECT minivec_cmp('[1,2,3]', '[0,0,0]'); + minivec_cmp +------------- + 1 +(1 row) + +SELECT minivec_cmp('[0,0,0]', '[1,2,3]'); + minivec_cmp +------------- + -1 +(1 row) + +SELECT minivec_cmp('[1,2]', '[1,2,3]'); + minivec_cmp +------------- + -1 +(1 row) + +SELECT minivec_cmp('[1,2,3]', '[1,2]'); + minivec_cmp +------------- + 1 +(1 row) + +SELECT minivec_cmp('[1,2]', '[2,3,4]'); + minivec_cmp +------------- + -1 +(1 row) + +SELECT minivec_cmp('[2,3]', '[1,2,3]'); + minivec_cmp +------------- + 1 +(1 row) + +SELECT vector_dims('[1,2,3]'::minivec); + vector_dims +------------- + 3 +(1 row) + +SELECT round(l2_norm('[1,1]'::minivec)::numeric, 5); + round +--------- + 1.41421 +(1 row) + +SELECT l2_norm('[3,4]'::minivec); + l2_norm +--------- + 5 +(1 row) + +SELECT l2_norm('[0,1]'::minivec); + l2_norm +--------- + 1 +(1 row) + +SELECT l2_norm('[0,0]'::minivec); + l2_norm +--------- + 0 +(1 row) + +SELECT l2_norm('[2]'::minivec); + l2_norm +--------- + 2 +(1 row) + +SELECT l2_distance('[0,0]'::minivec, '[3,4]'); + l2_distance +------------- + 5 +(1 row) + +SELECT l2_distance('[0,0]'::minivec, '[0,1]'); + l2_distance +------------- + 1 +(1 row) + +SELECT l2_distance('[1,2]'::minivec, '[3]'); +ERROR: different minivec dimensions 2 and 1 +SELECT l2_distance('[1,1,1,1,1,1,1,1,1]'::minivec, '[1,1,1,1,1,1,1,4,5]'); + l2_distance +------------- + 5 +(1 row) + +SELECT '[0,0]'::minivec <-> '[3,4]'; + ?column? 
+---------- + 5 +(1 row) + +SELECT inner_product('[1,2]'::minivec, '[3,4]'); + inner_product +--------------- + 11 +(1 row) + +SELECT inner_product('[1,2]'::minivec, '[3]'); +ERROR: different minivec dimensions 2 and 1 +SELECT inner_product('[448]'::minivec, '[448]'); + inner_product +--------------- + 200704 +(1 row) + +SELECT inner_product('[1,1,1,1,1,1,1,1,1]'::minivec, '[1,2,3,4,5,6,7,8,9]'); + inner_product +--------------- + 45 +(1 row) + +SELECT '[1,2]'::minivec <#> '[3,4]'; + ?column? +---------- + -11 +(1 row) + +SELECT cosine_distance('[1,2]'::minivec, '[2,4]'); + cosine_distance +----------------- + 0 +(1 row) + +SELECT cosine_distance('[1,2]'::minivec, '[0,0]'); + cosine_distance +----------------- + NaN +(1 row) + +SELECT cosine_distance('[1,1]'::minivec, '[1,1]'); + cosine_distance +----------------- + 0 +(1 row) + +SELECT cosine_distance('[1,0]'::minivec, '[0,2]'); + cosine_distance +----------------- + 1 +(1 row) + +SELECT cosine_distance('[1,1]'::minivec, '[-1,-1]'); + cosine_distance +----------------- + 2 +(1 row) + +SELECT cosine_distance('[1,2]'::minivec, '[3]'); +ERROR: different minivec dimensions 2 and 1 +SELECT cosine_distance('[1,1]'::minivec, '[1.1,1.1]'); + cosine_distance +----------------- + 0 +(1 row) + +SELECT cosine_distance('[1,1]'::minivec, '[-1.1,-1.1]'); + cosine_distance +----------------- + 2 +(1 row) + +SELECT cosine_distance('[1,2,3,4,5,6,7,8,9]'::minivec, '[1,2,3,4,5,6,7,8,9]'); + cosine_distance +----------------- + 0 +(1 row) + +SELECT cosine_distance('[1,2,3,4,5,6,7,8,9]'::minivec, '[-1,-2,-3,-4,-5,-6,-7,-8,-9]'); + cosine_distance +----------------- + 2 +(1 row) + +SELECT '[1,2]'::minivec <=> '[2,4]'; + ?column? 
+---------- + 0 +(1 row) + +SELECT l1_distance('[0,0]'::minivec, '[3,4]'); + l1_distance +------------- + 7 +(1 row) + +SELECT l1_distance('[0,0]'::minivec, '[0,1]'); + l1_distance +------------- + 1 +(1 row) + +SELECT l1_distance('[1,2]'::minivec, '[3]'); +ERROR: different minivec dimensions 2 and 1 +SELECT l1_distance('[1,2,3,4,5,6,7,8,9]'::minivec, '[1,2,3,4,5,6,7,8,9]'); + l1_distance +------------- + 0 +(1 row) + +SELECT l1_distance('[1,2,3,4,5,6,7,8,9]'::minivec, '[0,3,2,5,4,7,6,9,8]'); + l1_distance +------------- + 9 +(1 row) + +SELECT '[0,0]'::minivec <+> '[3,4]'; + ?column? +---------- + 7 +(1 row) + +SELECT l2_normalize('[3,4]'::minivec); + l2_normalize +---------------- + [0.625,0.8125] +(1 row) + +SELECT l2_normalize('[3,0]'::minivec); + l2_normalize +-------------- + [1,0] +(1 row) + +SELECT l2_normalize('[0,0.1]'::minivec); + l2_normalize +-------------- + [0,1] +(1 row) + +SELECT l2_normalize('[0,0]'::minivec); + l2_normalize +-------------- + [0,0] +(1 row) + +SELECT l2_normalize('[448]'::minivec); + l2_normalize +-------------- + [1] +(1 row) + +SELECT binary_quantize('[1,0,-1]'::minivec); + binary_quantize +----------------- + 100 +(1 row) + +SELECT binary_quantize('[0,0.1,-0.2,-0.3,0.4,0.5,0.6,-0.7,0.8,-0.9,1]'::minivec); + binary_quantize +----------------- + 01001110101 +(1 row) + +SELECT subvector('[1,2,3,4,5]'::minivec, 1, 3); + subvector +----------- + [1,2,3] +(1 row) + +SELECT subvector('[1,2,3,4,5]'::minivec, 3, 2); + subvector +----------- + [3,4] +(1 row) + +SELECT subvector('[1,2,3,4,5]'::minivec, -1, 3); + subvector +----------- + [1] +(1 row) + +SELECT subvector('[1,2,3,4,5]'::minivec, 3, 9); + subvector +----------- + [3,4,5] +(1 row) + +SELECT subvector('[1,2,3,4,5]'::minivec, 1, 0); +ERROR: minivec must have at least 1 dimension +SELECT subvector('[1,2,3,4,5]'::minivec, 3, -1); +ERROR: minivec must have at least 1 dimension +SELECT subvector('[1,2,3,4,5]'::minivec, -1, 2); +ERROR: minivec must have at least 1 dimension +SELECT 
subvector('[1,2,3,4,5]'::minivec, 2147483647, 10); +ERROR: minivec must have at least 1 dimension +SELECT subvector('[1,2,3,4,5]'::minivec, 3, 2147483647); + subvector +----------- + [3,4,5] +(1 row) + +SELECT subvector('[1,2,3,4,5]'::minivec, -2147483644, 2147483647); + subvector +----------- + [1,2] +(1 row) + diff --git a/test/sql/btree.sql b/test/sql/btree.sql index de583c3..6a4c69c 100644 --- a/test/sql/btree.sql +++ b/test/sql/btree.sql @@ -22,6 +22,17 @@ SELECT * FROM t ORDER BY val; DROP TABLE t; +-- minivec + +CREATE TABLE t (val minivec(3)); +INSERT INTO t (val) VALUES ('[0,0,0]'), ('[1,2,3]'), ('[1,1,1]'), (NULL); +CREATE INDEX ON t (val); + +SELECT * FROM t WHERE val = '[1,2,3]'; +SELECT * FROM t ORDER BY val; + +DROP TABLE t; + -- sparsevec CREATE TABLE t (val sparsevec(3)); diff --git a/test/sql/cast.sql b/test/sql/cast.sql index fe83931..5db8436 100644 --- a/test/sql/cast.sql +++ b/test/sql/cast.sql @@ -38,6 +38,22 @@ SELECT '{1,2,3}'::real[]::halfvec(2); SELECT '{65520,-65520}'::real[]::halfvec; SELECT '{1e-8,-1e-8}'::real[]::halfvec; +SELECT '[1,2,3]'::vector::minivec; +SELECT '[1,2,3]'::vector::minivec(3); +SELECT '[1,2,3]'::vector::minivec(2); +SELECT '[465]'::vector::minivec; +SELECT '[1e-8]'::vector::minivec; + +SELECT '[1,2,3]'::minivec::vector; +SELECT '[1,2,3]'::minivec::vector(3); +SELECT '[1,2,3]'::minivec::vector(2); + +SELECT '{1,2,3}'::real[]::minivec; +SELECT '{1,2,3}'::real[]::minivec(3); +SELECT '{1,2,3}'::real[]::minivec(2); +SELECT '{465,-465}'::real[]::minivec; +SELECT '{1e-8,-1e-8}'::real[]::minivec; + SELECT '[0,1.5,0,3.5,0]'::vector::sparsevec; SELECT '[0,1.5,0,3.5,0]'::vector::sparsevec(5); SELECT '[0,1.5,0,3.5,0]'::vector::sparsevec(4); diff --git a/test/sql/copy.sql b/test/sql/copy.sql index 2dff3ff..1012801 100644 --- a/test/sql/copy.sql +++ b/test/sql/copy.sql @@ -28,6 +28,21 @@ SELECT * FROM t2 ORDER BY val; DROP TABLE t; DROP TABLE t2; +-- minivec + +CREATE TABLE t (val minivec(3)); +INSERT INTO t (val) VALUES 
('[0,0,0]'), ('[1,2,3]'), ('[1,1,1]'), (NULL); + +CREATE TABLE t2 (val minivec(3)); + +\copy t TO 'results/minivec.bin' WITH (FORMAT binary) +\copy t2 FROM 'results/minivec.bin' WITH (FORMAT binary) + +SELECT * FROM t2 ORDER BY val; + +DROP TABLE t; +DROP TABLE t2; + -- sparsevec CREATE TABLE t (val sparsevec(3)); diff --git a/test/sql/minivec.sql b/test/sql/minivec.sql new file mode 100644 index 0000000..22f67d4 --- /dev/null +++ b/test/sql/minivec.sql @@ -0,0 +1,134 @@ +SELECT '[1,2,3]'::minivec; +SELECT '[-1,-2,-3]'::minivec; +SELECT '[1.,2.,3.]'::minivec; +SELECT ' [ 1, 2 , 3 ] '::minivec; +SELECT '[1.23456]'::minivec; +SELECT '[hello,1]'::minivec; +SELECT '[NaN,1]'::minivec; +SELECT '[Infinity,1]'::minivec; +SELECT '[-Infinity,1]'::minivec; +SELECT '[65519,-65519]'::minivec; +SELECT '[65520,-65520]'::minivec; +SELECT '[1e-8,-1e-8]'::minivec; +SELECT '[4e38,1]'::minivec; +SELECT '[1e-46,1]'::minivec; +SELECT '[1,2,3'::minivec; +SELECT '[1,2,3]9'::minivec; +SELECT '1,2,3'::minivec; +SELECT ''::minivec; +SELECT '['::minivec; +SELECT '[ '::minivec; +SELECT '[,'::minivec; +SELECT '[]'::minivec; +SELECT '[ ]'::minivec; +SELECT '[,]'::minivec; +SELECT '[1,]'::minivec; +SELECT '[1a]'::minivec; +SELECT '[1,,3]'::minivec; +SELECT '[1, ,3]'::minivec; + +SELECT '[1,2,3]'::minivec(3); +SELECT '[1,2,3]'::minivec(2); +SELECT '[1,2,3]'::minivec(3, 2); +SELECT '[1,2,3]'::minivec('a'); +SELECT '[1,2,3]'::minivec(0); +SELECT '[1,2,3]'::minivec(16001); + +SELECT unnest('{"[1,2,3]", "[4,5,6]"}'::minivec[]); +SELECT '{"[1,2,3]"}'::minivec(2)[]; + +SELECT '[1,2,3]'::minivec + '[4,5,6]'; +SELECT '[448]'::minivec + '[448]'; +SELECT '[1,2]'::minivec + '[3]'; + +SELECT '[1,2,3]'::minivec - '[4,5,6]'; +SELECT '[-448]'::minivec - '[448]'; +SELECT '[1,2]'::minivec - '[3]'; + +SELECT '[1,2,3]'::minivec * '[4,5,6]'; +SELECT '[448]'::minivec * '[448]'; +SELECT '[1e-7]'::minivec * '[1e-7]'; +SELECT '[1,2]'::minivec * '[3]'; + +SELECT '[1,2,3]'::minivec || '[4,5]'; +SELECT array_fill(0, 
ARRAY[16000])::minivec || '[1]'; + +SELECT '[1,2,3]'::minivec < '[1,2,3]'; +SELECT '[1,2,3]'::minivec < '[1,2]'; +SELECT '[1,2,3]'::minivec <= '[1,2,3]'; +SELECT '[1,2,3]'::minivec <= '[1,2]'; +SELECT '[1,2,3]'::minivec = '[1,2,3]'; +SELECT '[1,2,3]'::minivec = '[1,2]'; +SELECT '[1,2,3]'::minivec != '[1,2,3]'; +SELECT '[1,2,3]'::minivec != '[1,2]'; +SELECT '[1,2,3]'::minivec >= '[1,2,3]'; +SELECT '[1,2,3]'::minivec >= '[1,2]'; +SELECT '[1,2,3]'::minivec > '[1,2,3]'; +SELECT '[1,2,3]'::minivec > '[1,2]'; + +SELECT minivec_cmp('[1,2,3]', '[1,2,3]'); +SELECT minivec_cmp('[1,2,3]', '[0,0,0]'); +SELECT minivec_cmp('[0,0,0]', '[1,2,3]'); +SELECT minivec_cmp('[1,2]', '[1,2,3]'); +SELECT minivec_cmp('[1,2,3]', '[1,2]'); +SELECT minivec_cmp('[1,2]', '[2,3,4]'); +SELECT minivec_cmp('[2,3]', '[1,2,3]'); + +SELECT vector_dims('[1,2,3]'::minivec); + +SELECT round(l2_norm('[1,1]'::minivec)::numeric, 5); +SELECT l2_norm('[3,4]'::minivec); +SELECT l2_norm('[0,1]'::minivec); +SELECT l2_norm('[0,0]'::minivec); +SELECT l2_norm('[2]'::minivec); + +SELECT l2_distance('[0,0]'::minivec, '[3,4]'); +SELECT l2_distance('[0,0]'::minivec, '[0,1]'); +SELECT l2_distance('[1,2]'::minivec, '[3]'); +SELECT l2_distance('[1,1,1,1,1,1,1,1,1]'::minivec, '[1,1,1,1,1,1,1,4,5]'); +SELECT '[0,0]'::minivec <-> '[3,4]'; + +SELECT inner_product('[1,2]'::minivec, '[3,4]'); +SELECT inner_product('[1,2]'::minivec, '[3]'); +SELECT inner_product('[448]'::minivec, '[448]'); +SELECT inner_product('[1,1,1,1,1,1,1,1,1]'::minivec, '[1,2,3,4,5,6,7,8,9]'); +SELECT '[1,2]'::minivec <#> '[3,4]'; + +SELECT cosine_distance('[1,2]'::minivec, '[2,4]'); +SELECT cosine_distance('[1,2]'::minivec, '[0,0]'); +SELECT cosine_distance('[1,1]'::minivec, '[1,1]'); +SELECT cosine_distance('[1,0]'::minivec, '[0,2]'); +SELECT cosine_distance('[1,1]'::minivec, '[-1,-1]'); +SELECT cosine_distance('[1,2]'::minivec, '[3]'); +SELECT cosine_distance('[1,1]'::minivec, '[1.1,1.1]'); +SELECT cosine_distance('[1,1]'::minivec, '[-1.1,-1.1]'); 
+SELECT cosine_distance('[1,2,3,4,5,6,7,8,9]'::minivec, '[1,2,3,4,5,6,7,8,9]'); +SELECT cosine_distance('[1,2,3,4,5,6,7,8,9]'::minivec, '[-1,-2,-3,-4,-5,-6,-7,-8,-9]'); +SELECT '[1,2]'::minivec <=> '[2,4]'; + +SELECT l1_distance('[0,0]'::minivec, '[3,4]'); +SELECT l1_distance('[0,0]'::minivec, '[0,1]'); +SELECT l1_distance('[1,2]'::minivec, '[3]'); +SELECT l1_distance('[1,2,3,4,5,6,7,8,9]'::minivec, '[1,2,3,4,5,6,7,8,9]'); +SELECT l1_distance('[1,2,3,4,5,6,7,8,9]'::minivec, '[0,3,2,5,4,7,6,9,8]'); +SELECT '[0,0]'::minivec <+> '[3,4]'; + +SELECT l2_normalize('[3,4]'::minivec); +SELECT l2_normalize('[3,0]'::minivec); +SELECT l2_normalize('[0,0.1]'::minivec); +SELECT l2_normalize('[0,0]'::minivec); +SELECT l2_normalize('[448]'::minivec); + +SELECT binary_quantize('[1,0,-1]'::minivec); +SELECT binary_quantize('[0,0.1,-0.2,-0.3,0.4,0.5,0.6,-0.7,0.8,-0.9,1]'::minivec); + +SELECT subvector('[1,2,3,4,5]'::minivec, 1, 3); +SELECT subvector('[1,2,3,4,5]'::minivec, 3, 2); +SELECT subvector('[1,2,3,4,5]'::minivec, -1, 3); +SELECT subvector('[1,2,3,4,5]'::minivec, 3, 9); +SELECT subvector('[1,2,3,4,5]'::minivec, 1, 0); +SELECT subvector('[1,2,3,4,5]'::minivec, 3, -1); +SELECT subvector('[1,2,3,4,5]'::minivec, -1, 2); +SELECT subvector('[1,2,3,4,5]'::minivec, 2147483647, 10); +SELECT subvector('[1,2,3,4,5]'::minivec, 3, 2147483647); +SELECT subvector('[1,2,3,4,5]'::minivec, -2147483644, 2147483647); diff --git a/test/t/039_hnsw_minivec_build_recall.pl b/test/t/039_hnsw_minivec_build_recall.pl new file mode 100644 index 0000000..60cbaf0 --- /dev/null +++ b/test/t/039_hnsw_minivec_build_recall.pl @@ -0,0 +1,136 @@ +use strict; +use warnings FATAL => 'all'; +use PostgreSQL::Test::Cluster; +use PostgreSQL::Test::Utils; +use Test::More; + +my $node; +my @queries = (); +my @expected; +my $limit = 20; +my $dim = 10; +my $array_sql = join(",", ('2 * random() * random()') x $dim); + +sub test_recall +{ + my ($min, $operator) = @_; + my $correct = 0; + my $total = 0; + + my $explain = 
$node->safe_psql("postgres", qq(
+		SET enable_seqscan = off;
+		EXPLAIN ANALYZE SELECT i FROM tst ORDER BY v $operator '$queries[0]' LIMIT $limit;
+	));
+	# The plan for the first query must use the index
+	like($explain, qr/Index Scan/);
+
+	# Count how many of the exact ids each indexed query also returns
+	for my $i (0 .. $#queries)
+	{
+		my $actual = $node->safe_psql("postgres", qq(
+			SET enable_seqscan = off;
+			SELECT i FROM tst ORDER BY v $operator '$queries[$i]' LIMIT $limit;
+		));
+		my @actual_ids = split("\n", $actual);
+		my %actual_set = map { $_ => 1 } @actual_ids;
+
+		my @expected_ids = split("\n", $expected[$i]);
+
+		foreach (@expected_ids)
+		{
+			if (exists($actual_set{$_}))
+			{
+				$correct++;
+			}
+			$total++;
+		}
+	}
+
+	# Recall = matched ids / expected ids, summed over all queries
+	cmp_ok($correct / $total, ">=", $min, $operator);
+}
+
+# Initialize node
+$node = PostgreSQL::Test::Cluster->new('node');
+$node->init;
+$node->start;
+
+# Create table
+$node->safe_psql("postgres", "CREATE EXTENSION vector;");
+$node->safe_psql("postgres", "CREATE TABLE tst (i int4, v minivec($dim));");
+$node->safe_psql("postgres",
+	"INSERT INTO tst SELECT i, ARRAY[$array_sql] FROM generate_series(1, 10000) i;"
+);
+
+# Generate queries
+for (1 .. 20)
+{
+	my @r = ();
+	for (1 .. $dim)
+	{
+		push(@r, rand());
+	}
+	push(@queries, "[" . join(",", @r) . "]");
+}
+
+# Check each index type
+# @operators and @opclasses are parallel lists: entry $i of one goes with entry $i of the other
+my @operators = ("<->", "<#>", "<=>", "<+>");
+my @opclasses = ("minivec_l2_ops", "minivec_ip_ops", "minivec_cosine_ops", "minivec_l1_ops");
+
+for my $i (0 .. $#operators)
+{
+	my $operator = $operators[$i];
+	my $opclass = $opclasses[$i];
+
+	# Get exact results
+	# No index named idx exists at this point: it is created below and dropped before the next iteration
+	@expected = ();
+	foreach (@queries)
+	{
+		my $res = $node->safe_psql("postgres", "SELECT i FROM tst ORDER BY v $operator '$_' LIMIT $limit;");
+		push(@expected, $res);
+	}
+
+	# Build index serially
+	$node->safe_psql("postgres", qq(
+		SET max_parallel_maintenance_workers = 0;
+		CREATE INDEX idx ON tst USING hnsw (v $opclass);
+	));
+
+	# Test approximate results
+	# The cosine operator is held to a lower minimum recall than the others
+	my $min = 0.98;
+	if ($operator eq '<=>')
+	{
+		$min = 0.65;
+	}
+	test_recall($min, $operator);
+
+	$node->safe_psql("postgres", "DROP INDEX idx;");
+
+	# Build index in parallel in memory
+	my ($ret, $stdout, $stderr) = $node->psql("postgres", qq(
+		SET client_min_messages = DEBUG;
+		SET min_parallel_table_scan_size = 1;
+		CREATE INDEX idx ON tst USING hnsw (v $opclass);
+	));
+	# On failure the captured stderr is shown as the test name
+	is($ret, 0, $stderr);
+	like($stderr, qr/using \d+ parallel workers/);
+
+	# Test approximate results
+	test_recall($min, $operator);
+
+	$node->safe_psql("postgres", "DROP INDEX idx;");
+
+	# Build index in parallel on disk
+	# Set parallel_workers on table to use workers with low maintenance_work_mem
+	($ret, $stdout, $stderr) = $node->psql("postgres", qq(
+		ALTER TABLE tst SET (parallel_workers = 2);
+		SET client_min_messages = DEBUG;
+		SET maintenance_work_mem = '4MB';
+		CREATE INDEX idx ON tst USING hnsw (v $opclass);
+		ALTER TABLE tst RESET (parallel_workers);
+	));
+	is($ret, 0, $stderr);
+	like($stderr, qr/using \d+ parallel workers/);
+	# Only the build messages are checked for this variant; recall is not measured
+	like($stderr, qr/hnsw graph no longer fits into maintenance_work_mem/);
+
+	$node->safe_psql("postgres", "DROP INDEX idx;");
+}
+
+done_testing();