# sql/vector--0.7.4--0.8.0.sql: removes the minivec shell type, its I/O
# functions, l2_distance, the <-> operator and the hnsw operator class from
# the upgrade script, leaving only a TODO placeholder in their place.
diff --git a/sql/vector--0.7.4--0.8.0.sql b/sql/vector--0.7.4--0.8.0.sql
index 0c67ec8..84b711e 100644
--- a/sql/vector--0.7.4--0.8.0.sql
+++ b/sql/vector--0.7.4--0.8.0.sql
@@ -1,51 +1,7 @@
 -- complain if script is sourced in psql, rather than via CREATE EXTENSION
 \echo Use "ALTER EXTENSION vector UPDATE TO '0.8.0'" to load this file. \quit
 
-CREATE FUNCTION hnsw_minivec_support(internal) RETURNS internal
-	AS 'MODULE_PATHNAME' LANGUAGE C;
-
-CREATE TYPE minivec;
-
-CREATE FUNCTION minivec_in(cstring, oid, integer) RETURNS minivec
-	AS 'MODULE_PATHNAME' LANGUAGE C IMMUTABLE STRICT PARALLEL SAFE;
-
-CREATE FUNCTION minivec_out(minivec) RETURNS cstring
-	AS 'MODULE_PATHNAME' LANGUAGE C IMMUTABLE STRICT PARALLEL SAFE;
-
-CREATE FUNCTION minivec_typmod_in(cstring[]) RETURNS integer
-	AS 'MODULE_PATHNAME' LANGUAGE C IMMUTABLE STRICT PARALLEL SAFE;
-
-CREATE FUNCTION minivec_recv(internal, oid, integer) RETURNS minivec
-	AS 'MODULE_PATHNAME' LANGUAGE C IMMUTABLE STRICT PARALLEL SAFE;
-
-CREATE FUNCTION minivec_send(minivec) RETURNS bytea
-	AS 'MODULE_PATHNAME' LANGUAGE C IMMUTABLE STRICT PARALLEL SAFE;
-
-CREATE TYPE minivec (
-	INPUT = minivec_in,
-	OUTPUT = minivec_out,
-	TYPMOD_IN = minivec_typmod_in,
-	RECEIVE = minivec_recv,
-	SEND = minivec_send,
-	STORAGE = external
-);
-
-CREATE FUNCTION l2_distance(minivec, minivec) RETURNS float8
-	AS 'MODULE_PATHNAME', 'minivec_l2_distance' LANGUAGE C IMMUTABLE STRICT PARALLEL SAFE;
-
-CREATE FUNCTION minivec_l2_squared_distance(minivec, minivec) RETURNS float8
-	AS 'MODULE_PATHNAME' LANGUAGE C IMMUTABLE STRICT PARALLEL SAFE;
-
-CREATE OPERATOR <-> (
-	LEFTARG = minivec, RIGHTARG = minivec, PROCEDURE = l2_distance,
-	COMMUTATOR = '<->'
-);
-
-CREATE OPERATOR CLASS minivec_l2_ops
-	FOR TYPE minivec USING hnsw AS
-	OPERATOR 1 <-> (minivec, minivec) FOR ORDER BY float_ops,
-	FUNCTION 1 minivec_l2_squared_distance(minivec, minivec),
-	FUNCTION 3 hnsw_minivec_support(internal);
+-- TODO minivec functions
 
 CREATE FUNCTION array_to_sparsevec(integer[], integer, boolean) RETURNS sparsevec
 	AS 'MODULE_PATHNAME' LANGUAGE C IMMUTABLE STRICT PARALLEL SAFE;
# sql/vector.sql: adds the public minivec functions (inner_product,
# cosine_distance, l1_distance, vector_dims, l2_norm, l2_normalize,
# binary_quantize, subvector), the private arithmetic/comparison support
# functions, the operators built on them, a default btree operator class, and
# hnsw operator classes for inner product, cosine and L1 distance.
diff --git a/sql/vector.sql b/sql/vector.sql
index 1502293..30d74c7 100644
--- a/sql/vector.sql
+++ b/sql/vector.sql
@@ -683,11 +683,71 @@ CREATE TYPE minivec (
 CREATE FUNCTION l2_distance(minivec, minivec) RETURNS float8
 	AS 'MODULE_PATHNAME', 'minivec_l2_distance' LANGUAGE C IMMUTABLE STRICT PARALLEL SAFE;
 
+CREATE FUNCTION inner_product(minivec, minivec) RETURNS float8
+	AS 'MODULE_PATHNAME', 'minivec_inner_product' LANGUAGE C IMMUTABLE STRICT PARALLEL SAFE;
+
+CREATE FUNCTION cosine_distance(minivec, minivec) RETURNS float8
+	AS 'MODULE_PATHNAME', 'minivec_cosine_distance' LANGUAGE C IMMUTABLE STRICT PARALLEL SAFE;
+
+CREATE FUNCTION l1_distance(minivec, minivec) RETURNS float8
+	AS 'MODULE_PATHNAME', 'minivec_l1_distance' LANGUAGE C IMMUTABLE STRICT PARALLEL SAFE;
+
+CREATE FUNCTION vector_dims(minivec) RETURNS integer
+	AS 'MODULE_PATHNAME', 'minivec_vector_dims' LANGUAGE C IMMUTABLE STRICT PARALLEL SAFE;
+
+CREATE FUNCTION l2_norm(minivec) RETURNS float8
+	AS 'MODULE_PATHNAME', 'minivec_l2_norm' LANGUAGE C IMMUTABLE STRICT PARALLEL SAFE;
+
+CREATE FUNCTION l2_normalize(minivec) RETURNS minivec
+	AS 'MODULE_PATHNAME', 'minivec_l2_normalize' LANGUAGE C IMMUTABLE STRICT PARALLEL SAFE;
+
+CREATE FUNCTION binary_quantize(minivec) RETURNS bit
+	AS 'MODULE_PATHNAME', 'minivec_binary_quantize' LANGUAGE C IMMUTABLE STRICT PARALLEL SAFE;
+
+CREATE FUNCTION subvector(minivec, int, int) RETURNS minivec
+	AS 'MODULE_PATHNAME', 'minivec_subvector' LANGUAGE C IMMUTABLE STRICT PARALLEL SAFE;
+
 -- minivec private functions
 
+CREATE FUNCTION minivec_add(minivec, minivec) RETURNS minivec
+	AS 'MODULE_PATHNAME' LANGUAGE C IMMUTABLE STRICT PARALLEL SAFE;
+
+CREATE FUNCTION minivec_sub(minivec, minivec) RETURNS minivec
+	AS 'MODULE_PATHNAME' LANGUAGE C IMMUTABLE STRICT PARALLEL SAFE;
+
+CREATE FUNCTION minivec_mul(minivec, minivec) RETURNS minivec
+	AS 'MODULE_PATHNAME' LANGUAGE C IMMUTABLE STRICT PARALLEL SAFE;
+
+CREATE FUNCTION minivec_concat(minivec, minivec) RETURNS minivec
+	AS 'MODULE_PATHNAME' LANGUAGE C IMMUTABLE STRICT PARALLEL SAFE;
+
+CREATE FUNCTION minivec_lt(minivec, minivec) RETURNS bool
+	AS 'MODULE_PATHNAME' LANGUAGE C IMMUTABLE STRICT PARALLEL SAFE;
+
+CREATE FUNCTION minivec_le(minivec, minivec) RETURNS bool
+	AS 'MODULE_PATHNAME' LANGUAGE C IMMUTABLE STRICT PARALLEL SAFE;
+
+CREATE FUNCTION minivec_eq(minivec, minivec) RETURNS bool
+	AS 'MODULE_PATHNAME' LANGUAGE C IMMUTABLE STRICT PARALLEL SAFE;
+
+CREATE FUNCTION minivec_ne(minivec, minivec) RETURNS bool
+	AS 'MODULE_PATHNAME' LANGUAGE C IMMUTABLE STRICT PARALLEL SAFE;
+
+CREATE FUNCTION minivec_ge(minivec, minivec) RETURNS bool
+	AS 'MODULE_PATHNAME' LANGUAGE C IMMUTABLE STRICT PARALLEL SAFE;
+
+CREATE FUNCTION minivec_gt(minivec, minivec) RETURNS bool
+	AS 'MODULE_PATHNAME' LANGUAGE C IMMUTABLE STRICT PARALLEL SAFE;
+
+CREATE FUNCTION minivec_cmp(minivec, minivec) RETURNS int4
+	AS 'MODULE_PATHNAME' LANGUAGE C IMMUTABLE STRICT PARALLEL SAFE;
+
 CREATE FUNCTION minivec_l2_squared_distance(minivec, minivec) RETURNS float8
 	AS 'MODULE_PATHNAME' LANGUAGE C IMMUTABLE STRICT PARALLEL SAFE;
 
+CREATE FUNCTION minivec_negative_inner_product(minivec, minivec) RETURNS float8
+	AS 'MODULE_PATHNAME' LANGUAGE C IMMUTABLE STRICT PARALLEL SAFE;
+
 -- minivec operators
 
 CREATE OPERATOR <-> (
@@ -695,14 +755,111 @@ CREATE OPERATOR <-> (
 	COMMUTATOR = '<->'
 );
 
+CREATE OPERATOR <#> (
+	LEFTARG = minivec, RIGHTARG = minivec, PROCEDURE = minivec_negative_inner_product,
+	COMMUTATOR = '<#>'
+);
+
+CREATE OPERATOR <=> (
+	LEFTARG = minivec, RIGHTARG = minivec, PROCEDURE = cosine_distance,
+	COMMUTATOR = '<=>'
+);
+
+CREATE OPERATOR <+> (
+	LEFTARG = minivec, RIGHTARG = minivec, PROCEDURE = l1_distance,
+	COMMUTATOR = '<+>'
+);
+
+CREATE OPERATOR + (
+	LEFTARG = minivec, RIGHTARG = minivec, PROCEDURE = minivec_add,
+	COMMUTATOR = +
+);
+
+CREATE OPERATOR - (
+	LEFTARG = minivec, RIGHTARG = minivec, PROCEDURE = minivec_sub
+);
+
+CREATE OPERATOR * (
+	LEFTARG = minivec, RIGHTARG = minivec, PROCEDURE = minivec_mul,
+	COMMUTATOR = *
+);
+
+CREATE OPERATOR || (
+	LEFTARG = minivec, RIGHTARG = minivec, PROCEDURE = minivec_concat
+);
+
+CREATE OPERATOR < (
+	LEFTARG = minivec, RIGHTARG = minivec, PROCEDURE = minivec_lt,
+	COMMUTATOR = > , NEGATOR = >= ,
+	RESTRICT = scalarltsel, JOIN = scalarltjoinsel
+);
+
+CREATE OPERATOR <= (
+	LEFTARG = minivec, RIGHTARG = minivec, PROCEDURE = minivec_le,
+	COMMUTATOR = >= , NEGATOR = > ,
+	RESTRICT = scalarlesel, JOIN = scalarlejoinsel
+);
+
+CREATE OPERATOR = (
+	LEFTARG = minivec, RIGHTARG = minivec, PROCEDURE = minivec_eq,
+	COMMUTATOR = = , NEGATOR = <> ,
+	RESTRICT = eqsel, JOIN = eqjoinsel
+);
+
+CREATE OPERATOR <> (
+	LEFTARG = minivec, RIGHTARG = minivec, PROCEDURE = minivec_ne,
+	COMMUTATOR = <> , NEGATOR = = ,
+	RESTRICT = eqsel, JOIN = eqjoinsel
+);
+
+CREATE OPERATOR >= (
+	LEFTARG = minivec, RIGHTARG = minivec, PROCEDURE = minivec_ge,
+	COMMUTATOR = <= , NEGATOR = < ,
+	RESTRICT = scalargesel, JOIN = scalargejoinsel
+);
+
+CREATE OPERATOR > (
+	LEFTARG = minivec, RIGHTARG = minivec, PROCEDURE = minivec_gt,
+	COMMUTATOR = < , NEGATOR = <= ,
+	RESTRICT = scalargtsel, JOIN = scalargtjoinsel
+);
+
 -- minivec op classes
 
+CREATE OPERATOR CLASS minivec_ops
+	DEFAULT FOR TYPE minivec USING btree AS
+	OPERATOR 1 < ,
+	OPERATOR 2 <= ,
+	OPERATOR 3 = ,
+	OPERATOR 4 >= ,
+	OPERATOR 5 > ,
+	FUNCTION 1 minivec_cmp(minivec, minivec);
+
 CREATE OPERATOR CLASS minivec_l2_ops
 	FOR TYPE minivec USING hnsw AS
 	OPERATOR 1 <-> (minivec, minivec) FOR ORDER BY float_ops,
 	FUNCTION 1 minivec_l2_squared_distance(minivec, minivec),
 	FUNCTION 3 hnsw_minivec_support(internal);
 
+CREATE OPERATOR CLASS minivec_ip_ops
+	FOR TYPE minivec USING hnsw AS
+	OPERATOR 1 <#> (minivec, minivec) FOR ORDER BY float_ops,
+	FUNCTION 1 minivec_negative_inner_product(minivec, minivec),
+	FUNCTION 3 hnsw_minivec_support(internal);
+
+CREATE OPERATOR CLASS minivec_cosine_ops
+	FOR TYPE minivec USING hnsw AS
+	OPERATOR 1 <=> (minivec, minivec) FOR ORDER BY float_ops,
+	FUNCTION 1 minivec_negative_inner_product(minivec, minivec),
+	FUNCTION 2 l2_norm(minivec),
+	FUNCTION 3 hnsw_minivec_support(internal);
+
+CREATE OPERATOR CLASS minivec_l1_ops
+	FOR TYPE minivec USING hnsw AS
+	OPERATOR 1 <+> (minivec, minivec) FOR ORDER BY float_ops,
+	FUNCTION 1 l1_distance(minivec, minivec),
+	FUNCTION 3 hnsw_minivec_support(internal);
+
 -- bit functions
 
 CREATE FUNCTION hamming_distance(bit, bit) RETURNS float8
# src/minivec.c: adds the C implementations behind the minivec SQL functions:
# inner product, cosine / spherical / L1 distance, dims, L2 norm and
# normalize, element-wise add / sub / mul, concat, binary quantize, subvector,
# and the comparison functions built on minivec_cmp_internal.
# Element-wise arithmetic always decodes fp8 to float4 and re-encodes the
# result; overflow is reported when the re-encoded element is NaN.
diff --git a/src/minivec.c b/src/minivec.c
index d7ab626..de66ac8 100644
--- a/src/minivec.c
+++ b/src/minivec.c
@@ -2,6 +2,7 @@
 
 #include <math.h>
 
+#include "bitvec.h"
 #include "catalog/pg_type.h"
 #include "common/shortest_dec.h"
 #include "fmgr.h"
@@ -71,7 +72,7 @@ CheckElement(fp8 value)
 }
 
 /*
- * Allocate and initialize a new half vector
+ * Allocate and initialize a new fp8 vector
  */
 MiniVector *
 InitMiniVector(int dim)
@@ -352,7 +353,6 @@ MinivecL2SquaredDistance(int dim, fp8 * ax, fp8 * bx)
 {
 	float		distance = 0.0;
 
-	/* Auto-vectorized */
 	for (int i = 0; i < dim; i++)
 	{
 		float		diff = Fp8ToFloat4(ax[i]) - Fp8ToFloat4(bx[i]);
@@ -392,3 +392,548 @@ minivec_l2_squared_distance(PG_FUNCTION_ARGS)
 
 	PG_RETURN_FLOAT8((double) MinivecL2SquaredDistance(a->dim, a->x, b->x));
 }
+
+static float
+MinivecInnerProduct(int dim, fp8 * ax, fp8 * bx)
+{
+	float		distance = 0.0;
+
+	for (int i = 0; i < dim; i++)
+		distance += Fp8ToFloat4(ax[i]) * Fp8ToFloat4(bx[i]);
+
+	return distance;
+}
+
+/*
+ * Get the inner product of two fp8 vectors
+ */
+FUNCTION_PREFIX PG_FUNCTION_INFO_V1(minivec_inner_product);
+Datum
+minivec_inner_product(PG_FUNCTION_ARGS)
+{
+	MiniVector *a = PG_GETARG_MINIVEC_P(0);
+	MiniVector *b = PG_GETARG_MINIVEC_P(1);
+
+	CheckDims(a, b);
+
+	PG_RETURN_FLOAT8((double) MinivecInnerProduct(a->dim, a->x, b->x));
+}
+
+/*
+ * Get the negative inner product of two fp8 vectors
+ */
+FUNCTION_PREFIX PG_FUNCTION_INFO_V1(minivec_negative_inner_product);
+Datum
+minivec_negative_inner_product(PG_FUNCTION_ARGS)
+{
+	MiniVector *a = PG_GETARG_MINIVEC_P(0);
+	MiniVector *b = PG_GETARG_MINIVEC_P(1);
+
+	CheckDims(a, b);
+
+	PG_RETURN_FLOAT8((double) -MinivecInnerProduct(a->dim, a->x, b->x));
+}
+
+static double
+MinivecCosineSimilarity(int dim, fp8 * ax, fp8 * bx)
+{
+	float		similarity = 0.0;
+	float		norma = 0.0;
+	float		normb = 0.0;
+
+	for (int i = 0; i < dim; i++)
+	{
+		float		axi = Fp8ToFloat4(ax[i]);
+		float		bxi = Fp8ToFloat4(bx[i]);
+
+		similarity += axi * bxi;
+		norma += axi * axi;
+		normb += bxi * bxi;
+	}
+
+	/* Use sqrt(a * b) over sqrt(a) * sqrt(b) */
+	return (double) similarity / sqrt((double) norma * (double) normb);
+}
+
+/*
+ * Get the cosine distance between two fp8 vectors
+ */
+FUNCTION_PREFIX PG_FUNCTION_INFO_V1(minivec_cosine_distance);
+Datum
+minivec_cosine_distance(PG_FUNCTION_ARGS)
+{
+	MiniVector *a = PG_GETARG_MINIVEC_P(0);
+	MiniVector *b = PG_GETARG_MINIVEC_P(1);
+	double		similarity;
+
+	CheckDims(a, b);
+
+	similarity = MinivecCosineSimilarity(a->dim, a->x, b->x);
+
+#ifdef _MSC_VER
+	/* /fp:fast may not propagate NaN */
+	if (isnan(similarity))
+		PG_RETURN_FLOAT8(NAN);
+#endif
+
+	/* Keep in range */
+	if (similarity > 1)
+		similarity = 1;
+	else if (similarity < -1)
+		similarity = -1;
+
+	PG_RETURN_FLOAT8(1 - similarity);
+}
+
+/*
+ * Get the distance for spherical k-means
+ * Currently uses angular distance since needs to satisfy triangle inequality
+ * Assumes inputs are unit vectors (skips norm)
+ */
+FUNCTION_PREFIX PG_FUNCTION_INFO_V1(minivec_spherical_distance);
+Datum
+minivec_spherical_distance(PG_FUNCTION_ARGS)
+{
+	MiniVector *a = PG_GETARG_MINIVEC_P(0);
+	MiniVector *b = PG_GETARG_MINIVEC_P(1);
+	double		distance;
+
+	CheckDims(a, b);
+
+	distance = (double) MinivecInnerProduct(a->dim, a->x, b->x);
+
+	/* Prevent NaN with acos with loss of precision */
+	if (distance > 1)
+		distance = 1;
+	else if (distance < -1)
+		distance = -1;
+
+	PG_RETURN_FLOAT8(acos(distance) / M_PI);
+}
+
+static float
+MinivecL1Distance(int dim, fp8 * ax, fp8 * bx)
+{
+	float		distance = 0.0;
+
+	/* Auto-vectorized */
+	for (int i = 0; i < dim; i++)
+		distance += fabsf(Fp8ToFloat4(ax[i]) - Fp8ToFloat4(bx[i]));
+
+	return distance;
+}
+
+/*
+ * Get the L1 distance between two fp8 vectors
+ */
+FUNCTION_PREFIX PG_FUNCTION_INFO_V1(minivec_l1_distance);
+Datum
+minivec_l1_distance(PG_FUNCTION_ARGS)
+{
+	MiniVector *a = PG_GETARG_MINIVEC_P(0);
+	MiniVector *b = PG_GETARG_MINIVEC_P(1);
+
+	CheckDims(a, b);
+
+	PG_RETURN_FLOAT8((double) MinivecL1Distance(a->dim, a->x, b->x));
+}
+
+/*
+ * Get the dimensions of a fp8 vector
+ */
+FUNCTION_PREFIX PG_FUNCTION_INFO_V1(minivec_vector_dims);
+Datum
+minivec_vector_dims(PG_FUNCTION_ARGS)
+{
+	MiniVector *a = PG_GETARG_MINIVEC_P(0);
+
+	PG_RETURN_INT32(a->dim);
+}
+
+/*
+ * Get the L2 norm of a fp8 vector
+ */
+FUNCTION_PREFIX PG_FUNCTION_INFO_V1(minivec_l2_norm);
+Datum
+minivec_l2_norm(PG_FUNCTION_ARGS)
+{
+	MiniVector *a = PG_GETARG_MINIVEC_P(0);
+	fp8		   *ax = a->x;
+	double		norm = 0.0;
+
+	/* Auto-vectorized */
+	for (int i = 0; i < a->dim; i++)
+	{
+		double		axi = (double) Fp8ToFloat4(ax[i]);
+
+		norm += axi * axi;
+	}
+
+	PG_RETURN_FLOAT8(sqrt(norm));
+}
+
+/*
+ * Normalize a fp8 vector with the L2 norm
+ */
+FUNCTION_PREFIX PG_FUNCTION_INFO_V1(minivec_l2_normalize);
+Datum
+minivec_l2_normalize(PG_FUNCTION_ARGS)
+{
+	MiniVector *a = PG_GETARG_MINIVEC_P(0);
+	fp8		   *ax = a->x;
+	double		norm = 0;
+	MiniVector *result;
+	fp8		   *rx;
+
+	result = InitMiniVector(a->dim);
+	rx = result->x;
+
+	/* Auto-vectorized */
+	for (int i = 0; i < a->dim; i++)
+		norm += (double) Fp8ToFloat4(ax[i]) * (double) Fp8ToFloat4(ax[i]);
+
+	norm = sqrt(norm);
+
+	/* Return zero vector for zero norm */
+	if (norm > 0)
+	{
+		for (int i = 0; i < a->dim; i++)
+			rx[i] = Float4ToFp8Unchecked(Fp8ToFloat4(ax[i]) / norm);
+
+		/* Check for overflow */
+		for (int i = 0; i < a->dim; i++)
+		{
+			if (Fp8IsNan(rx[i]))
+				float_overflow_error();
+		}
+	}
+
+	PG_RETURN_POINTER(result);
+}
+
+/*
+ * Add fp8 vectors
+ */
+FUNCTION_PREFIX PG_FUNCTION_INFO_V1(minivec_add);
+Datum
+minivec_add(PG_FUNCTION_ARGS)
+{
+	MiniVector *a = PG_GETARG_MINIVEC_P(0);
+	MiniVector *b = PG_GETARG_MINIVEC_P(1);
+	fp8		   *ax = a->x;
+	fp8		   *bx = b->x;
+	MiniVector *result;
+	fp8		   *rx;
+
+	CheckDims(a, b);
+
+	result = InitMiniVector(a->dim);
+	rx = result->x;
+
+	/* Auto-vectorized */
+	for (int i = 0, imax = a->dim; i < imax; i++)
+	{
+		/*
+		 * fp8 is an encoded byte, not a native float: decode both
+		 * operands to float4, add, then re-encode
+		 */
+		rx[i] = Float4ToFp8Unchecked(Fp8ToFloat4(ax[i]) + Fp8ToFloat4(bx[i]));
+	}
+
+	/* Check for overflow */
+	for (int i = 0, imax = a->dim; i < imax; i++)
+	{
+		if (Fp8IsNan(rx[i]))
+			float_overflow_error();
+	}
+
+	PG_RETURN_POINTER(result);
+}
+
+/*
+ * Subtract fp8 vectors
+ */
+FUNCTION_PREFIX PG_FUNCTION_INFO_V1(minivec_sub);
+Datum
+minivec_sub(PG_FUNCTION_ARGS)
+{
+	MiniVector *a = PG_GETARG_MINIVEC_P(0);
+	MiniVector *b = PG_GETARG_MINIVEC_P(1);
+	fp8		   *ax = a->x;
+	fp8		   *bx = b->x;
+	MiniVector *result;
+	fp8		   *rx;
+
+	CheckDims(a, b);
+
+	result = InitMiniVector(a->dim);
+	rx = result->x;
+
+	/* Auto-vectorized */
+	for (int i = 0, imax = a->dim; i < imax; i++)
+	{
+		/*
+		 * fp8 is an encoded byte, not a native float: decode both
+		 * operands to float4, subtract, then re-encode
+		 */
+		rx[i] = Float4ToFp8Unchecked(Fp8ToFloat4(ax[i]) - Fp8ToFloat4(bx[i]));
+	}
+
+	/* Check for overflow */
+	for (int i = 0, imax = a->dim; i < imax; i++)
+	{
+		if (Fp8IsNan(rx[i]))
+			float_overflow_error();
+	}
+
+	PG_RETURN_POINTER(result);
+}
+
+/*
+ * Multiply fp8 vectors
+ */
+FUNCTION_PREFIX PG_FUNCTION_INFO_V1(minivec_mul);
+Datum
+minivec_mul(PG_FUNCTION_ARGS)
+{
+	MiniVector *a = PG_GETARG_MINIVEC_P(0);
+	MiniVector *b = PG_GETARG_MINIVEC_P(1);
+	fp8		   *ax = a->x;
+	fp8		   *bx = b->x;
+	MiniVector *result;
+	fp8		   *rx;
+
+	CheckDims(a, b);
+
+	result = InitMiniVector(a->dim);
+	rx = result->x;
+
+	/* Auto-vectorized */
+	for (int i = 0, imax = a->dim; i < imax; i++)
+	{
+		/*
+		 * fp8 is an encoded byte, not a native float: decode both
+		 * operands to float4, multiply, then re-encode
+		 */
+		rx[i] = Float4ToFp8Unchecked(Fp8ToFloat4(ax[i]) * Fp8ToFloat4(bx[i]));
+	}
+
+	/* Check for overflow and underflow */
+	for (int i = 0, imax = a->dim; i < imax; i++)
+	{
+		if (Fp8IsNan(rx[i]))
+			float_overflow_error();
+
+		if (Fp8IsZero(rx[i]) && !(Fp8IsZero(ax[i]) || Fp8IsZero(bx[i])))
+			float_underflow_error();
+	}
+
+	PG_RETURN_POINTER(result);
+}
+
+/*
+ * Concatenate fp8 vectors
+ */
+FUNCTION_PREFIX PG_FUNCTION_INFO_V1(minivec_concat);
+Datum
+minivec_concat(PG_FUNCTION_ARGS)
+{
+	MiniVector *a = PG_GETARG_MINIVEC_P(0);
+	MiniVector *b = PG_GETARG_MINIVEC_P(1);
+	MiniVector *result;
+	int			dim = a->dim + b->dim;
+
+	CheckDim(dim);
+	result = InitMiniVector(dim);
+
+	for (int i = 0; i < a->dim; i++)
+		result->x[i] = a->x[i];
+
+	for (int i = 0; i < b->dim; i++)
+		result->x[i + a->dim] = b->x[i];
+
+	PG_RETURN_POINTER(result);
+}
+
+/*
+ * Quantize a fp8 vector
+ */
+FUNCTION_PREFIX PG_FUNCTION_INFO_V1(minivec_binary_quantize);
+Datum
+minivec_binary_quantize(PG_FUNCTION_ARGS)
+{
+	MiniVector *a = PG_GETARG_MINIVEC_P(0);
+	fp8		   *ax = a->x;
+	VarBit	   *result = InitBitVector(a->dim);
+	unsigned char *rx = VARBITS(result);
+
+	for (int i = 0; i < a->dim; i++)
+		rx[i / 8] |= (Fp8ToFloat4(ax[i]) > 0) << (7 - (i % 8));
+
+	PG_RETURN_VARBIT_P(result);
+}
+
+/*
+ * Get a subvector
+ */
+FUNCTION_PREFIX PG_FUNCTION_INFO_V1(minivec_subvector);
+Datum
+minivec_subvector(PG_FUNCTION_ARGS)
+{
+	MiniVector *a = PG_GETARG_MINIVEC_P(0);
+	int32		start = PG_GETARG_INT32(1);
+	int32		count = PG_GETARG_INT32(2);
+	int32		end;
+	fp8		   *ax = a->x;
+	MiniVector *result;
+	int32		dim;
+
+	if (count < 1)
+		ereport(ERROR,
+				(errcode(ERRCODE_DATA_EXCEPTION),
+				 errmsg("minivec must have at least 1 dimension")));
+
+	/*
+	 * Check if (start + count > a->dim), avoiding integer overflow. a->dim
+	 * and count are both positive, so a->dim - count won't overflow.
+	 */
+	if (start > a->dim - count)
+		end = a->dim + 1;
+	else
+		end = start + count;
+
+	/* Indexing starts at 1, like substring */
+	if (start < 1)
+		start = 1;
+	else if (start > a->dim)
+		ereport(ERROR,
+				(errcode(ERRCODE_DATA_EXCEPTION),
+				 errmsg("minivec must have at least 1 dimension")));
+
+	dim = end - start;
+	CheckDim(dim);
+	result = InitMiniVector(dim);
+
+	for (int i = 0; i < dim; i++)
+		result->x[i] = ax[start - 1 + i];
+
+	PG_RETURN_POINTER(result);
+}
+
+/*
+ * Internal helper to compare fp8 vectors
+ */
+static int
+minivec_cmp_internal(MiniVector * a, MiniVector * b)
+{
+	int			dim = Min(a->dim, b->dim);
+
+	/* Check values before dimensions to be consistent with Postgres arrays */
+	for (int i = 0; i < dim; i++)
+	{
+		if (Fp8ToFloat4(a->x[i]) < Fp8ToFloat4(b->x[i]))
+			return -1;
+
+		if (Fp8ToFloat4(a->x[i]) > Fp8ToFloat4(b->x[i]))
+			return 1;
+	}
+
+	if (a->dim < b->dim)
+		return -1;
+
+	if (a->dim > b->dim)
+		return 1;
+
+	return 0;
+}
+
+/*
+ * Less than
+ */
+FUNCTION_PREFIX PG_FUNCTION_INFO_V1(minivec_lt);
+Datum
+minivec_lt(PG_FUNCTION_ARGS)
+{
+	MiniVector *a = PG_GETARG_MINIVEC_P(0);
+	MiniVector *b = PG_GETARG_MINIVEC_P(1);
+
+	PG_RETURN_BOOL(minivec_cmp_internal(a, b) < 0);
+}
+
+/*
+ * Less than or equal
+ */
+FUNCTION_PREFIX PG_FUNCTION_INFO_V1(minivec_le);
+Datum
+minivec_le(PG_FUNCTION_ARGS)
+{
+	MiniVector *a = PG_GETARG_MINIVEC_P(0);
+	MiniVector *b = PG_GETARG_MINIVEC_P(1);
+
+	PG_RETURN_BOOL(minivec_cmp_internal(a, b) <= 0);
+}
+
+/*
+ * Equal
+ */
+FUNCTION_PREFIX PG_FUNCTION_INFO_V1(minivec_eq);
+Datum
+minivec_eq(PG_FUNCTION_ARGS)
+{
+	MiniVector *a = PG_GETARG_MINIVEC_P(0);
+	MiniVector *b = PG_GETARG_MINIVEC_P(1);
+
+	PG_RETURN_BOOL(minivec_cmp_internal(a, b) == 0);
+}
+
+/*
+ * Not equal
+ */
+FUNCTION_PREFIX PG_FUNCTION_INFO_V1(minivec_ne);
+Datum
+minivec_ne(PG_FUNCTION_ARGS)
+{
+	MiniVector *a = PG_GETARG_MINIVEC_P(0);
+	MiniVector *b = PG_GETARG_MINIVEC_P(1);
+
+	PG_RETURN_BOOL(minivec_cmp_internal(a, b) != 0);
+}
+
+/*
+ * Greater than or equal
+ */
+FUNCTION_PREFIX PG_FUNCTION_INFO_V1(minivec_ge);
+Datum
+minivec_ge(PG_FUNCTION_ARGS)
+{
+	MiniVector *a = PG_GETARG_MINIVEC_P(0);
+	MiniVector *b = PG_GETARG_MINIVEC_P(1);
+
+	PG_RETURN_BOOL(minivec_cmp_internal(a, b) >= 0);
+}
+
+/*
+ * Greater than
+ */
+FUNCTION_PREFIX PG_FUNCTION_INFO_V1(minivec_gt);
+Datum
+minivec_gt(PG_FUNCTION_ARGS)
+{
+	MiniVector *a = PG_GETARG_MINIVEC_P(0);
+	MiniVector *b = PG_GETARG_MINIVEC_P(1);
+
+	PG_RETURN_BOOL(minivec_cmp_internal(a, b) > 0);
+}
+
+/*
+ * Compare fp8 vectors
+ */
+FUNCTION_PREFIX PG_FUNCTION_INFO_V1(minivec_cmp);
+Datum
+minivec_cmp(PG_FUNCTION_ARGS)
+{
+	MiniVector *a = PG_GETARG_MINIVEC_P(0);
+	MiniVector *b = PG_GETARG_MINIVEC_P(1);
+
+	PG_RETURN_INT32(minivec_cmp_internal(a, b));
+}
# src/minivec.h: adds Fp8IsZero. It masks off bit 7 the same way Fp8IsNan
# does, so both 0x00 and 0x80 are treated as zero; minivec_mul relies on this
# for its underflow check.
diff --git a/src/minivec.h b/src/minivec.h
index 7acd813..5e1e2c2 100644
--- a/src/minivec.h
+++ b/src/minivec.h
@@ -31,6 +31,15 @@ Fp8IsNan(fp8 num)
 	return (num & 0x7F) == 0x7F;
 }
 
+/*
+ * Check if fp8 is zero (either sign; bit 7 is masked off as in Fp8IsNan)
+ */
+static inline bool
+Fp8IsZero(fp8 num)
+{
+	return (num & 0x7F) == 0;
+}
+
 float lookup[128] = {0, 0.00195312, 0.00390625, 0.00585938, 0.0078125, 0.00976562, 0.0117188, 0.0136719, 0.015625, 0.0175781, 0.0195312, 0.0214844, 0.0234375, 0.0253906, 0.0273438, 0.0292969, 0.03125, 0.0351562, 0.0390625, 0.0429688, 0.046875, 0.0507812, 0.0546875, 0.0585938, 0.0625, 0.0703125, 0.078125, 0.0859375, 0.09375, 0.101562, 0.109375, 0.117188, 0.125, 0.140625, 0.15625, 0.171875, 0.1875, 0.203125, 0.21875, 0.234375, 0.25, 0.28125, 0.3125, 0.34375, 0.375, 0.40625, 0.4375, 0.46875, 0.5, 0.5625, 0.625, 0.6875, 0.75, 0.8125, 0.875, 0.9375, 1, 1.125, 1.25, 1.375, 1.5, 1.625, 1.75, 1.875, 2, 2.25, 2.5, 2.75, 3, 3.25, 3.5, 3.75, 4, 4.5, 5, 5.5, 6, 6.5, 7, 7.5, 8, 9, 10, 11, 12, 13, 14, 15, 16, 18, 20, 22, 24, 26, 28, 30, 32, 36, 40, 44, 48, 52, 56, 60, 64, 72, 80, 88, 96, 104, 112, 120, 128, 144, 160, 176, 192, 208, 224, 240, 256, 288, 320, 352, 384, 416, 448, NAN};
 
 /*
diff --git a/test/sql/btree.sql b/test/sql/btree.sql
index de583c3..6a4c69c 100644
--- a/test/sql/btree.sql
+++ b/test/sql/btree.sql
@@ -22,6 +22,17 @@ SELECT * FROM t ORDER BY val;
 
 DROP TABLE t;
 
+-- minivec
+
+CREATE TABLE t (val minivec(3));
+INSERT INTO t (val) VALUES ('[0,0,0]'), ('[1,2,3]'), ('[1,1,1]'), (NULL);
+CREATE INDEX ON t (val);
+
+SELECT * FROM t WHERE val = '[1,2,3]';
+SELECT * FROM t ORDER BY val;
+
+DROP TABLE t;
+
 -- sparsevec
 
 CREATE TABLE t (val sparsevec(3));
# test/sql/copy.sql: re-enables the previously commented-out ORDER BY check
# on the table loaded through binary COPY (presumably possible now that
# minivec has comparison operators and a default btree operator class;
# confirm against the expected output file).
diff --git a/test/sql/copy.sql b/test/sql/copy.sql
index a9431dc..1012801 100644
--- a/test/sql/copy.sql
+++ b/test/sql/copy.sql
@@ -38,7 +38,7 @@ CREATE TABLE t2 (val minivec(3));
 \copy t TO 'results/minivec.bin' WITH (FORMAT binary)
 \copy t2 FROM 'results/minivec.bin' WITH (FORMAT binary)
 
---SELECT * FROM t2 ORDER BY val;
+SELECT * FROM t2 ORDER BY val;
 
 DROP TABLE t;
 DROP TABLE t2;