mirror of
https://github.com/pgvector/pgvector.git
synced 2026-07-02 10:40:57 +08:00
Added avg for half vectors [skip ci]
This commit is contained in:
112
src/halfvec.c
112
src/halfvec.c
@@ -23,6 +23,9 @@
|
||||
#define TYPALIGN_INT 'i'
|
||||
#endif
|
||||
|
||||
#define STATE_DIMS(x) (ARR_DIMS(x)[0] - 1)
|
||||
#define CreateStateDatums(dim) palloc(sizeof(Datum) * (dim + 1))
|
||||
|
||||
/*
|
||||
* Get a half from a message buffer
|
||||
*/
|
||||
@@ -146,6 +149,20 @@ halfvec_isspace(char ch)
|
||||
return false;
|
||||
}
|
||||
|
||||
/*
|
||||
* Check state array
|
||||
*/
|
||||
static float8 *
|
||||
CheckStateArray(ArrayType *statearray, const char *caller)
|
||||
{
|
||||
if (ARR_NDIM(statearray) != 1 ||
|
||||
ARR_DIMS(statearray)[0] < 1 ||
|
||||
ARR_HASNULL(statearray) ||
|
||||
ARR_ELEMTYPE(statearray) != FLOAT8OID)
|
||||
elog(ERROR, "%s: expected state array", caller);
|
||||
return (float8 *) ARR_DATA_PTR(statearray);
|
||||
}
|
||||
|
||||
#if PG_VERSION_NUM < 120003
|
||||
static pg_noinline void
|
||||
float_overflow_error(void)
|
||||
@@ -1016,3 +1033,98 @@ halfvec_cmp(PG_FUNCTION_ARGS)
|
||||
|
||||
PG_RETURN_INT32(halfvec_cmp_internal(a, b));
|
||||
}
|
||||
|
||||
/*
|
||||
* Accumulate half vectors
|
||||
*/
|
||||
PGDLLEXPORT PG_FUNCTION_INFO_V1(halfvec_accum);
|
||||
Datum
|
||||
halfvec_accum(PG_FUNCTION_ARGS)
|
||||
{
|
||||
ArrayType *statearray = PG_GETARG_ARRAYTYPE_P(0);
|
||||
HalfVector *newval = PG_GETARG_HALFVEC_P(1);
|
||||
float8 *statevalues;
|
||||
int16 dim;
|
||||
bool newarr;
|
||||
float8 n;
|
||||
Datum *statedatums;
|
||||
half *x = newval->x;
|
||||
ArrayType *result;
|
||||
|
||||
/* Check array before using */
|
||||
statevalues = CheckStateArray(statearray, "halfvec_accum");
|
||||
dim = STATE_DIMS(statearray);
|
||||
newarr = dim == 0;
|
||||
|
||||
if (newarr)
|
||||
dim = newval->dim;
|
||||
else
|
||||
CheckExpectedDim(dim, newval->dim);
|
||||
|
||||
n = statevalues[0] + 1.0;
|
||||
|
||||
statedatums = CreateStateDatums(dim);
|
||||
statedatums[0] = Float8GetDatum(n);
|
||||
|
||||
if (newarr)
|
||||
{
|
||||
for (int i = 0; i < dim; i++)
|
||||
statedatums[i + 1] = Float8GetDatum((double) HalfToFloat4(x[i]));
|
||||
}
|
||||
else
|
||||
{
|
||||
for (int i = 0; i < dim; i++)
|
||||
{
|
||||
double v = statevalues[i + 1] + (double) HalfToFloat4(x[i]);
|
||||
|
||||
/* Check for overflow */
|
||||
if (isinf(v))
|
||||
float_overflow_error();
|
||||
|
||||
statedatums[i + 1] = Float8GetDatum(v);
|
||||
}
|
||||
}
|
||||
|
||||
/* Use float8 array like float4_accum */
|
||||
result = construct_array(statedatums, dim + 1,
|
||||
FLOAT8OID,
|
||||
sizeof(float8), FLOAT8PASSBYVAL, TYPALIGN_DOUBLE);
|
||||
|
||||
pfree(statedatums);
|
||||
|
||||
PG_RETURN_ARRAYTYPE_P(result);
|
||||
}
|
||||
|
||||
/*
|
||||
* Average half vectors
|
||||
*/
|
||||
PGDLLEXPORT PG_FUNCTION_INFO_V1(halfvec_avg);
|
||||
Datum
|
||||
halfvec_avg(PG_FUNCTION_ARGS)
|
||||
{
|
||||
ArrayType *statearray = PG_GETARG_ARRAYTYPE_P(0);
|
||||
float8 *statevalues;
|
||||
float8 n;
|
||||
uint16 dim;
|
||||
HalfVector *result;
|
||||
|
||||
/* Check array before using */
|
||||
statevalues = CheckStateArray(statearray, "halfvec_avg");
|
||||
n = statevalues[0];
|
||||
|
||||
/* SQL defines AVG of no values to be NULL */
|
||||
if (n == 0.0)
|
||||
PG_RETURN_NULL();
|
||||
|
||||
/* Create half vector */
|
||||
dim = STATE_DIMS(statearray);
|
||||
CheckDim(dim);
|
||||
result = InitHalfVector(dim);
|
||||
for (int i = 0; i < dim; i++)
|
||||
{
|
||||
result->x[i] = Float4ToHalf(statevalues[i + 1] / n);
|
||||
CheckElement(result->x[i]);
|
||||
}
|
||||
|
||||
PG_RETURN_POINTER(result);
|
||||
}
|
||||
|
||||
@@ -1100,7 +1100,7 @@ vector_accum(PG_FUNCTION_ARGS)
|
||||
}
|
||||
|
||||
/*
|
||||
* Combine vectors
|
||||
* Combine vectors or half vectors
|
||||
*/
|
||||
PGDLLEXPORT PG_FUNCTION_INFO_V1(vector_combine);
|
||||
Datum
|
||||
|
||||
Reference in New Issue
Block a user