Added avg for half vectors [skip ci]

This commit is contained in:
Andrew Kane
2024-04-14 15:11:11 -07:00
parent 92d08bb6f5
commit e146f3cfb6
7 changed files with 186 additions and 1 deletions

View File

@@ -902,6 +902,12 @@ l2_distance(halfvec, halfvec) → double precision | Euclidean distance | unrele
quantize_binary(halfvec) → bit | quantize | unreleased
subvector(halfvec, integer, integer) → halfvec | subvector | unreleased
### Halfvec Aggregate Functions
Function | Description | Added
--- | --- | ---
avg(halfvec) → halfvec | average | unreleased
### Bit Type
Each bit vector takes `dimensions / 8 + 8` bytes of storage. See the [Postgres docs](https://www.postgresql.org/docs/current/datatype-bit.html) for more info.

View File

@@ -122,6 +122,21 @@ CREATE FUNCTION halfvec_negative_inner_product(halfvec, halfvec) RETURNS float8
CREATE FUNCTION halfvec_spherical_distance(halfvec, halfvec) RETURNS float8
AS 'MODULE_PATHNAME' LANGUAGE C IMMUTABLE STRICT PARALLEL SAFE;
CREATE FUNCTION halfvec_accum(double precision[], halfvec) RETURNS double precision[]
AS 'MODULE_PATHNAME' LANGUAGE C IMMUTABLE STRICT PARALLEL SAFE;
CREATE FUNCTION halfvec_avg(double precision[]) RETURNS halfvec
AS 'MODULE_PATHNAME' LANGUAGE C IMMUTABLE STRICT PARALLEL SAFE;
CREATE AGGREGATE avg(halfvec) (
SFUNC = halfvec_accum,
STYPE = double precision[],
FINALFUNC = halfvec_avg,
COMBINEFUNC = vector_combine,
INITCOND = '{0}',
PARALLEL = SAFE
);
CREATE FUNCTION halfvec(halfvec, integer, boolean) RETURNS halfvec
AS 'MODULE_PATHNAME' LANGUAGE C IMMUTABLE STRICT PARALLEL SAFE;

View File

@@ -417,6 +417,23 @@ CREATE FUNCTION halfvec_negative_inner_product(halfvec, halfvec) RETURNS float8
CREATE FUNCTION halfvec_spherical_distance(halfvec, halfvec) RETURNS float8
AS 'MODULE_PATHNAME' LANGUAGE C IMMUTABLE STRICT PARALLEL SAFE;
CREATE FUNCTION halfvec_accum(double precision[], halfvec) RETURNS double precision[]
AS 'MODULE_PATHNAME' LANGUAGE C IMMUTABLE STRICT PARALLEL SAFE;
CREATE FUNCTION halfvec_avg(double precision[]) RETURNS halfvec
AS 'MODULE_PATHNAME' LANGUAGE C IMMUTABLE STRICT PARALLEL SAFE;
-- halfvec aggregates
CREATE AGGREGATE avg(halfvec) (
SFUNC = halfvec_accum,
STYPE = double precision[],
FINALFUNC = halfvec_avg,
COMBINEFUNC = vector_combine,
INITCOND = '{0}',
PARALLEL = SAFE
);
-- halfvec cast functions
CREATE FUNCTION halfvec(halfvec, integer, boolean) RETURNS halfvec

View File

@@ -23,6 +23,9 @@
#define TYPALIGN_INT 'i'
#endif
#define STATE_DIMS(x) (ARR_DIMS(x)[0] - 1)
#define CreateStateDatums(dim) palloc(sizeof(Datum) * (dim + 1))
/*
* Get a half from a message buffer
*/
@@ -146,6 +149,20 @@ halfvec_isspace(char ch)
return false;
}
/*
* Check state array
*/
static float8 *
CheckStateArray(ArrayType *statearray, const char *caller)
{
if (ARR_NDIM(statearray) != 1 ||
ARR_DIMS(statearray)[0] < 1 ||
ARR_HASNULL(statearray) ||
ARR_ELEMTYPE(statearray) != FLOAT8OID)
elog(ERROR, "%s: expected state array", caller);
return (float8 *) ARR_DATA_PTR(statearray);
}
#if PG_VERSION_NUM < 120003
static pg_noinline void
float_overflow_error(void)
@@ -1016,3 +1033,98 @@ halfvec_cmp(PG_FUNCTION_ARGS)
PG_RETURN_INT32(halfvec_cmp_internal(a, b));
}
/*
* Accumulate half vectors
*/
PGDLLEXPORT PG_FUNCTION_INFO_V1(halfvec_accum);
Datum
halfvec_accum(PG_FUNCTION_ARGS)
{
ArrayType *statearray = PG_GETARG_ARRAYTYPE_P(0);
HalfVector *newval = PG_GETARG_HALFVEC_P(1);
float8 *statevalues;
int16 dim;
bool newarr;
float8 n;
Datum *statedatums;
half *x = newval->x;
ArrayType *result;
/* Check array before using */
statevalues = CheckStateArray(statearray, "halfvec_accum");
dim = STATE_DIMS(statearray);
newarr = dim == 0;
if (newarr)
dim = newval->dim;
else
CheckExpectedDim(dim, newval->dim);
n = statevalues[0] + 1.0;
statedatums = CreateStateDatums(dim);
statedatums[0] = Float8GetDatum(n);
if (newarr)
{
for (int i = 0; i < dim; i++)
statedatums[i + 1] = Float8GetDatum((double) HalfToFloat4(x[i]));
}
else
{
for (int i = 0; i < dim; i++)
{
double v = statevalues[i + 1] + (double) HalfToFloat4(x[i]);
/* Check for overflow */
if (isinf(v))
float_overflow_error();
statedatums[i + 1] = Float8GetDatum(v);
}
}
/* Use float8 array like float4_accum */
result = construct_array(statedatums, dim + 1,
FLOAT8OID,
sizeof(float8), FLOAT8PASSBYVAL, TYPALIGN_DOUBLE);
pfree(statedatums);
PG_RETURN_ARRAYTYPE_P(result);
}
/*
* Average half vectors
*/
PGDLLEXPORT PG_FUNCTION_INFO_V1(halfvec_avg);
Datum
halfvec_avg(PG_FUNCTION_ARGS)
{
ArrayType *statearray = PG_GETARG_ARRAYTYPE_P(0);
float8 *statevalues;
float8 n;
uint16 dim;
HalfVector *result;
/* Check array before using */
statevalues = CheckStateArray(statearray, "halfvec_avg");
n = statevalues[0];
/* SQL defines AVG of no values to be NULL */
if (n == 0.0)
PG_RETURN_NULL();
/* Create half vector */
dim = STATE_DIMS(statearray);
CheckDim(dim);
result = InitHalfVector(dim);
for (int i = 0; i < dim; i++)
{
result->x[i] = Float4ToHalf(statevalues[i + 1] / n);
CheckElement(result->x[i]);
}
PG_RETURN_POINTER(result);
}

View File

@@ -1100,7 +1100,7 @@ vector_accum(PG_FUNCTION_ARGS)
}
/*
* Combine vectors
* Combine vectors or half vectors
*/
PGDLLEXPORT PG_FUNCTION_INFO_V1(vector_combine);
Datum

View File

@@ -320,3 +320,31 @@ SELECT subvector('[1,2,3,4,5]'::halfvec, 3, -1);
ERROR: halfvec must have at least 1 dimension
SELECT subvector('[1,2,3,4,5]'::halfvec, -1, 2);
ERROR: halfvec must have at least 1 dimension
SELECT avg(v) FROM unnest(ARRAY['[1,2,3]'::halfvec, '[3,5,7]']) v;
avg
-----------
[2,3.5,5]
(1 row)
SELECT avg(v) FROM unnest(ARRAY['[1,2,3]'::halfvec, '[3,5,7]', NULL]) v;
avg
-----------
[2,3.5,5]
(1 row)
SELECT avg(v) FROM unnest(ARRAY[]::halfvec[]) v;
avg
-----
(1 row)
SELECT avg(v) FROM unnest(ARRAY['[1,2]'::halfvec, '[3]']) v;
ERROR: expected 2 dimensions, not 1
SELECT avg(v) FROM unnest(ARRAY['[65504]'::halfvec, '[65504]']) v;
avg
---------
[65504]
(1 row)
SELECT halfvec_avg(array_agg(n)) FROM generate_series(1, 16002) n;
ERROR: halfvec cannot have more than 16000 dimensions

View File

@@ -69,3 +69,10 @@ SELECT subvector('[1,2,3,4,5]'::halfvec, 3, 9);
SELECT subvector('[1,2,3,4,5]'::halfvec, 1, 0);
SELECT subvector('[1,2,3,4,5]'::halfvec, 3, -1);
SELECT subvector('[1,2,3,4,5]'::halfvec, -1, 2);
SELECT avg(v) FROM unnest(ARRAY['[1,2,3]'::halfvec, '[3,5,7]']) v;
SELECT avg(v) FROM unnest(ARRAY['[1,2,3]'::halfvec, '[3,5,7]', NULL]) v;
SELECT avg(v) FROM unnest(ARRAY[]::halfvec[]) v;
SELECT avg(v) FROM unnest(ARRAY['[1,2]'::halfvec, '[3]']) v;
SELECT avg(v) FROM unnest(ARRAY['[65504]'::halfvec, '[65504]']) v;
SELECT halfvec_avg(array_agg(n)) FROM generate_series(1, 16002) n;