mirror of
https://github.com/pgvector/pgvector.git
synced 2026-07-01 18:21:16 +08:00
Added avg aggregate for vector - closes #51
This commit is contained in:
181
src/vector.c
181
src/vector.c
@@ -23,6 +23,9 @@
|
||||
#define TYPALIGN_INT 'i'
|
||||
#endif
|
||||
|
||||
#define STATE_DIMS(x) (ARR_DIMS(x)[0] - 1)
|
||||
#define CreateStateDatums(dim) palloc(sizeof(Datum) * (dim + 1))
|
||||
|
||||
PG_MODULE_MAGIC;
|
||||
|
||||
/*
|
||||
@@ -82,6 +85,20 @@ CheckElement(float value)
|
||||
errmsg("infinite value not allowed in vector")));
|
||||
}
|
||||
|
||||
/*
|
||||
* Check state array
|
||||
*/
|
||||
static float8 *
|
||||
CheckStateArray(ArrayType *statearray, const char *caller)
|
||||
{
|
||||
if (ARR_NDIM(statearray) != 1 ||
|
||||
ARR_DIMS(statearray)[0] < 1 ||
|
||||
ARR_HASNULL(statearray) ||
|
||||
ARR_ELEMTYPE(statearray) != FLOAT8OID)
|
||||
elog(ERROR, "%s: expected state array", caller);
|
||||
return (float8 *) ARR_DATA_PTR(statearray);
|
||||
}
|
||||
|
||||
/*
|
||||
* Print vector - useful for debugging
|
||||
*/
|
||||
@@ -758,3 +775,167 @@ vector_cmp(PG_FUNCTION_ARGS)
|
||||
|
||||
PG_RETURN_INT32(vector_cmp_internal(a, b));
|
||||
}
|
||||
|
||||
/*
|
||||
* Accumulate vectors
|
||||
*/
|
||||
PGDLLEXPORT PG_FUNCTION_INFO_V1(vector_accum);
|
||||
Datum
|
||||
vector_accum(PG_FUNCTION_ARGS)
|
||||
{
|
||||
ArrayType *statearray = PG_GETARG_ARRAYTYPE_P(0);
|
||||
Vector *newval = PG_GETARG_VECTOR_P(1);
|
||||
float8 *statevalues;
|
||||
int16 dim;
|
||||
bool newarr;
|
||||
float8 n;
|
||||
Datum *statedatums;
|
||||
float *x = newval->x;
|
||||
ArrayType *result;
|
||||
|
||||
/* Check array before using */
|
||||
statevalues = CheckStateArray(statearray, "vector_accum");
|
||||
dim = STATE_DIMS(statearray);
|
||||
newarr = dim == 0;
|
||||
|
||||
if (newarr)
|
||||
dim = newval->dim;
|
||||
else
|
||||
CheckExpectedDim(dim, newval->dim);
|
||||
|
||||
n = statevalues[0] + 1.0;
|
||||
|
||||
statedatums = CreateStateDatums(dim);
|
||||
statedatums[0] = Float8GetDatumFast(n);
|
||||
|
||||
if (newarr)
|
||||
{
|
||||
for (int i = 0; i < dim; i++)
|
||||
statedatums[i + 1] = Float8GetDatumFast(x[i]);
|
||||
}
|
||||
else
|
||||
{
|
||||
for (int i = 0; i < dim; i++)
|
||||
{
|
||||
double v = statevalues[i + 1] + x[i];
|
||||
|
||||
if (isinf(v))
|
||||
float_overflow_error();
|
||||
|
||||
statedatums[i + 1] = Float8GetDatumFast(v);
|
||||
}
|
||||
}
|
||||
|
||||
/* Use float8 array like float4_accum */
|
||||
result = construct_array(statedatums, dim + 1,
|
||||
FLOAT8OID,
|
||||
sizeof(float8), FLOAT8PASSBYVAL, TYPALIGN_DOUBLE);
|
||||
|
||||
pfree(statedatums);
|
||||
|
||||
PG_RETURN_ARRAYTYPE_P(result);
|
||||
}
|
||||
|
||||
/*
|
||||
* Combine vectors
|
||||
*/
|
||||
PGDLLEXPORT PG_FUNCTION_INFO_V1(vector_combine);
|
||||
Datum
|
||||
vector_combine(PG_FUNCTION_ARGS)
|
||||
{
|
||||
ArrayType *statearray1 = PG_GETARG_ARRAYTYPE_P(0);
|
||||
ArrayType *statearray2 = PG_GETARG_ARRAYTYPE_P(1);
|
||||
float8 *statevalues1;
|
||||
float8 *statevalues2;
|
||||
float8 n;
|
||||
float8 n1;
|
||||
float8 n2;
|
||||
int16 dim;
|
||||
Datum *statedatums;
|
||||
ArrayType *result;
|
||||
|
||||
/* Check arrays before using */
|
||||
statevalues1 = CheckStateArray(statearray1, "vector_combine");
|
||||
statevalues2 = CheckStateArray(statearray2, "vector_combine");
|
||||
|
||||
n1 = statevalues1[0];
|
||||
n2 = statevalues2[0];
|
||||
|
||||
if (n1 == 0.0)
|
||||
{
|
||||
n = n2;
|
||||
dim = STATE_DIMS(statearray2);
|
||||
statedatums = CreateStateDatums(dim);
|
||||
for (int i = 1; i <= dim; i++)
|
||||
statedatums[i] = Float8GetDatumFast(statevalues2[i]);
|
||||
}
|
||||
else if (n2 == 0.0)
|
||||
{
|
||||
n = n1;
|
||||
dim = STATE_DIMS(statearray1);
|
||||
statedatums = CreateStateDatums(dim);
|
||||
for (int i = 1; i <= dim; i++)
|
||||
statedatums[i] = Float8GetDatumFast(statevalues1[i]);
|
||||
}
|
||||
else
|
||||
{
|
||||
n = n1 + n2;
|
||||
dim = STATE_DIMS(statearray1);
|
||||
CheckExpectedDim(dim, STATE_DIMS(statearray2));
|
||||
statedatums = CreateStateDatums(dim);
|
||||
for (int i = 1; i <= dim; i++)
|
||||
{
|
||||
double v = statevalues1[i] + statevalues2[i];
|
||||
|
||||
if (isinf(v))
|
||||
float_overflow_error();
|
||||
|
||||
statedatums[i] = Float8GetDatumFast(v);
|
||||
}
|
||||
}
|
||||
|
||||
statedatums[0] = Float8GetDatumFast(n);
|
||||
|
||||
result = construct_array(statedatums, dim + 1,
|
||||
FLOAT8OID,
|
||||
sizeof(float8), FLOAT8PASSBYVAL, TYPALIGN_DOUBLE);
|
||||
|
||||
pfree(statedatums);
|
||||
|
||||
PG_RETURN_ARRAYTYPE_P(result);
|
||||
}
|
||||
|
||||
/*
|
||||
* Average vectors
|
||||
*/
|
||||
PGDLLEXPORT PG_FUNCTION_INFO_V1(vector_avg);
|
||||
Datum
|
||||
vector_avg(PG_FUNCTION_ARGS)
|
||||
{
|
||||
ArrayType *statearray = PG_GETARG_ARRAYTYPE_P(0);
|
||||
float8 *statevalues;
|
||||
float8 n;
|
||||
uint16 dim;
|
||||
Vector *result;
|
||||
float v;
|
||||
|
||||
/* Check array before using */
|
||||
statevalues = CheckStateArray(statearray, "vector_avg");
|
||||
n = statevalues[0];
|
||||
|
||||
/* SQL defines AVG of no values to be NULL */
|
||||
if (n == 0.0)
|
||||
PG_RETURN_NULL();
|
||||
|
||||
/* Create vector */
|
||||
dim = STATE_DIMS(statearray);
|
||||
result = InitVector(dim);
|
||||
for (int i = 0; i < dim; i++)
|
||||
{
|
||||
v = statevalues[i + 1] / n;
|
||||
CheckElement(v);
|
||||
result->x[i] = v;
|
||||
}
|
||||
|
||||
PG_RETURN_POINTER(result);
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user