mirror of
https://github.com/pgvector/pgvector.git
synced 2026-06-06 05:51:21 +08:00
Added avg for half vectors [skip ci]
This commit is contained in:
@@ -902,6 +902,12 @@ l2_distance(halfvec, halfvec) → double precision | Euclidean distance | unrele
|
||||
quantize_binary(halfvec) → bit | quantize | unreleased
|
||||
subvector(halfvec, integer, integer) → halfvec | subvector | unreleased
|
||||
|
||||
### Halfvec Aggregate Functions
|
||||
|
||||
Function | Description | Added
|
||||
--- | --- | ---
|
||||
avg(halfvec) → halfvec | average | unreleased
|
||||
|
||||
### Bit Type
|
||||
|
||||
Each bit vector takes `dimensions / 8 + 8` bytes of storage. See the [Postgres docs](https://www.postgresql.org/docs/current/datatype-bit.html) for more info.
|
||||
|
||||
@@ -122,6 +122,21 @@ CREATE FUNCTION halfvec_negative_inner_product(halfvec, halfvec) RETURNS float8
|
||||
CREATE FUNCTION halfvec_spherical_distance(halfvec, halfvec) RETURNS float8
|
||||
AS 'MODULE_PATHNAME' LANGUAGE C IMMUTABLE STRICT PARALLEL SAFE;
|
||||
|
||||
CREATE FUNCTION halfvec_accum(double precision[], halfvec) RETURNS double precision[]
|
||||
AS 'MODULE_PATHNAME' LANGUAGE C IMMUTABLE STRICT PARALLEL SAFE;
|
||||
|
||||
CREATE FUNCTION halfvec_avg(double precision[]) RETURNS halfvec
|
||||
AS 'MODULE_PATHNAME' LANGUAGE C IMMUTABLE STRICT PARALLEL SAFE;
|
||||
|
||||
CREATE AGGREGATE avg(halfvec) (
|
||||
SFUNC = halfvec_accum,
|
||||
STYPE = double precision[],
|
||||
FINALFUNC = halfvec_avg,
|
||||
COMBINEFUNC = vector_combine,
|
||||
INITCOND = '{0}',
|
||||
PARALLEL = SAFE
|
||||
);
|
||||
|
||||
CREATE FUNCTION halfvec(halfvec, integer, boolean) RETURNS halfvec
|
||||
AS 'MODULE_PATHNAME' LANGUAGE C IMMUTABLE STRICT PARALLEL SAFE;
|
||||
|
||||
|
||||
@@ -417,6 +417,23 @@ CREATE FUNCTION halfvec_negative_inner_product(halfvec, halfvec) RETURNS float8
|
||||
CREATE FUNCTION halfvec_spherical_distance(halfvec, halfvec) RETURNS float8
|
||||
AS 'MODULE_PATHNAME' LANGUAGE C IMMUTABLE STRICT PARALLEL SAFE;
|
||||
|
||||
CREATE FUNCTION halfvec_accum(double precision[], halfvec) RETURNS double precision[]
|
||||
AS 'MODULE_PATHNAME' LANGUAGE C IMMUTABLE STRICT PARALLEL SAFE;
|
||||
|
||||
CREATE FUNCTION halfvec_avg(double precision[]) RETURNS halfvec
|
||||
AS 'MODULE_PATHNAME' LANGUAGE C IMMUTABLE STRICT PARALLEL SAFE;
|
||||
|
||||
-- halfvec aggregates
|
||||
|
||||
CREATE AGGREGATE avg(halfvec) (
|
||||
SFUNC = halfvec_accum,
|
||||
STYPE = double precision[],
|
||||
FINALFUNC = halfvec_avg,
|
||||
COMBINEFUNC = vector_combine,
|
||||
INITCOND = '{0}',
|
||||
PARALLEL = SAFE
|
||||
);
|
||||
|
||||
-- halfvec cast functions
|
||||
|
||||
CREATE FUNCTION halfvec(halfvec, integer, boolean) RETURNS halfvec
|
||||
|
||||
112
src/halfvec.c
112
src/halfvec.c
@@ -23,6 +23,9 @@
|
||||
#define TYPALIGN_INT 'i'
|
||||
#endif
|
||||
|
||||
#define STATE_DIMS(x) (ARR_DIMS(x)[0] - 1)
|
||||
#define CreateStateDatums(dim) palloc(sizeof(Datum) * (dim + 1))
|
||||
|
||||
/*
|
||||
* Get a half from a message buffer
|
||||
*/
|
||||
@@ -146,6 +149,20 @@ halfvec_isspace(char ch)
|
||||
return false;
|
||||
}
|
||||
|
||||
/*
|
||||
* Check state array
|
||||
*/
|
||||
static float8 *
|
||||
CheckStateArray(ArrayType *statearray, const char *caller)
|
||||
{
|
||||
if (ARR_NDIM(statearray) != 1 ||
|
||||
ARR_DIMS(statearray)[0] < 1 ||
|
||||
ARR_HASNULL(statearray) ||
|
||||
ARR_ELEMTYPE(statearray) != FLOAT8OID)
|
||||
elog(ERROR, "%s: expected state array", caller);
|
||||
return (float8 *) ARR_DATA_PTR(statearray);
|
||||
}
|
||||
|
||||
#if PG_VERSION_NUM < 120003
|
||||
static pg_noinline void
|
||||
float_overflow_error(void)
|
||||
@@ -1016,3 +1033,98 @@ halfvec_cmp(PG_FUNCTION_ARGS)
|
||||
|
||||
PG_RETURN_INT32(halfvec_cmp_internal(a, b));
|
||||
}
|
||||
|
||||
/*
|
||||
* Accumulate half vectors
|
||||
*/
|
||||
PGDLLEXPORT PG_FUNCTION_INFO_V1(halfvec_accum);
|
||||
Datum
|
||||
halfvec_accum(PG_FUNCTION_ARGS)
|
||||
{
|
||||
ArrayType *statearray = PG_GETARG_ARRAYTYPE_P(0);
|
||||
HalfVector *newval = PG_GETARG_HALFVEC_P(1);
|
||||
float8 *statevalues;
|
||||
int16 dim;
|
||||
bool newarr;
|
||||
float8 n;
|
||||
Datum *statedatums;
|
||||
half *x = newval->x;
|
||||
ArrayType *result;
|
||||
|
||||
/* Check array before using */
|
||||
statevalues = CheckStateArray(statearray, "halfvec_accum");
|
||||
dim = STATE_DIMS(statearray);
|
||||
newarr = dim == 0;
|
||||
|
||||
if (newarr)
|
||||
dim = newval->dim;
|
||||
else
|
||||
CheckExpectedDim(dim, newval->dim);
|
||||
|
||||
n = statevalues[0] + 1.0;
|
||||
|
||||
statedatums = CreateStateDatums(dim);
|
||||
statedatums[0] = Float8GetDatum(n);
|
||||
|
||||
if (newarr)
|
||||
{
|
||||
for (int i = 0; i < dim; i++)
|
||||
statedatums[i + 1] = Float8GetDatum((double) HalfToFloat4(x[i]));
|
||||
}
|
||||
else
|
||||
{
|
||||
for (int i = 0; i < dim; i++)
|
||||
{
|
||||
double v = statevalues[i + 1] + (double) HalfToFloat4(x[i]);
|
||||
|
||||
/* Check for overflow */
|
||||
if (isinf(v))
|
||||
float_overflow_error();
|
||||
|
||||
statedatums[i + 1] = Float8GetDatum(v);
|
||||
}
|
||||
}
|
||||
|
||||
/* Use float8 array like float4_accum */
|
||||
result = construct_array(statedatums, dim + 1,
|
||||
FLOAT8OID,
|
||||
sizeof(float8), FLOAT8PASSBYVAL, TYPALIGN_DOUBLE);
|
||||
|
||||
pfree(statedatums);
|
||||
|
||||
PG_RETURN_ARRAYTYPE_P(result);
|
||||
}
|
||||
|
||||
/*
|
||||
* Average half vectors
|
||||
*/
|
||||
PGDLLEXPORT PG_FUNCTION_INFO_V1(halfvec_avg);
|
||||
Datum
|
||||
halfvec_avg(PG_FUNCTION_ARGS)
|
||||
{
|
||||
ArrayType *statearray = PG_GETARG_ARRAYTYPE_P(0);
|
||||
float8 *statevalues;
|
||||
float8 n;
|
||||
uint16 dim;
|
||||
HalfVector *result;
|
||||
|
||||
/* Check array before using */
|
||||
statevalues = CheckStateArray(statearray, "halfvec_avg");
|
||||
n = statevalues[0];
|
||||
|
||||
/* SQL defines AVG of no values to be NULL */
|
||||
if (n == 0.0)
|
||||
PG_RETURN_NULL();
|
||||
|
||||
/* Create half vector */
|
||||
dim = STATE_DIMS(statearray);
|
||||
CheckDim(dim);
|
||||
result = InitHalfVector(dim);
|
||||
for (int i = 0; i < dim; i++)
|
||||
{
|
||||
result->x[i] = Float4ToHalf(statevalues[i + 1] / n);
|
||||
CheckElement(result->x[i]);
|
||||
}
|
||||
|
||||
PG_RETURN_POINTER(result);
|
||||
}
|
||||
|
||||
@@ -1100,7 +1100,7 @@ vector_accum(PG_FUNCTION_ARGS)
|
||||
}
|
||||
|
||||
/*
|
||||
* Combine vectors
|
||||
* Combine vectors or half vectors
|
||||
*/
|
||||
PGDLLEXPORT PG_FUNCTION_INFO_V1(vector_combine);
|
||||
Datum
|
||||
|
||||
@@ -320,3 +320,31 @@ SELECT subvector('[1,2,3,4,5]'::halfvec, 3, -1);
|
||||
ERROR: halfvec must have at least 1 dimension
|
||||
SELECT subvector('[1,2,3,4,5]'::halfvec, -1, 2);
|
||||
ERROR: halfvec must have at least 1 dimension
|
||||
SELECT avg(v) FROM unnest(ARRAY['[1,2,3]'::halfvec, '[3,5,7]']) v;
|
||||
avg
|
||||
-----------
|
||||
[2,3.5,5]
|
||||
(1 row)
|
||||
|
||||
SELECT avg(v) FROM unnest(ARRAY['[1,2,3]'::halfvec, '[3,5,7]', NULL]) v;
|
||||
avg
|
||||
-----------
|
||||
[2,3.5,5]
|
||||
(1 row)
|
||||
|
||||
SELECT avg(v) FROM unnest(ARRAY[]::halfvec[]) v;
|
||||
avg
|
||||
-----
|
||||
|
||||
(1 row)
|
||||
|
||||
SELECT avg(v) FROM unnest(ARRAY['[1,2]'::halfvec, '[3]']) v;
|
||||
ERROR: expected 2 dimensions, not 1
|
||||
SELECT avg(v) FROM unnest(ARRAY['[65504]'::halfvec, '[65504]']) v;
|
||||
avg
|
||||
---------
|
||||
[65504]
|
||||
(1 row)
|
||||
|
||||
SELECT halfvec_avg(array_agg(n)) FROM generate_series(1, 16002) n;
|
||||
ERROR: halfvec cannot have more than 16000 dimensions
|
||||
|
||||
@@ -69,3 +69,10 @@ SELECT subvector('[1,2,3,4,5]'::halfvec, 3, 9);
|
||||
SELECT subvector('[1,2,3,4,5]'::halfvec, 1, 0);
|
||||
SELECT subvector('[1,2,3,4,5]'::halfvec, 3, -1);
|
||||
SELECT subvector('[1,2,3,4,5]'::halfvec, -1, 2);
|
||||
|
||||
SELECT avg(v) FROM unnest(ARRAY['[1,2,3]'::halfvec, '[3,5,7]']) v;
|
||||
SELECT avg(v) FROM unnest(ARRAY['[1,2,3]'::halfvec, '[3,5,7]', NULL]) v;
|
||||
SELECT avg(v) FROM unnest(ARRAY[]::halfvec[]) v;
|
||||
SELECT avg(v) FROM unnest(ARRAY['[1,2]'::halfvec, '[3]']) v;
|
||||
SELECT avg(v) FROM unnest(ARRAY['[65504]'::halfvec, '[65504]']) v;
|
||||
SELECT halfvec_avg(array_agg(n)) FROM generate_series(1, 16002) n;
|
||||
|
||||
Reference in New Issue
Block a user