Added avg aggregate for vector - closes #51

This commit is contained in:
Andrew Kane
2022-12-30 17:22:25 -08:00
parent b400ac0f36
commit e09f93cba7
7 changed files with 272 additions and 0 deletions

View File

@@ -23,6 +23,9 @@
#define TYPALIGN_INT 'i'
#endif
#define STATE_DIMS(x) (ARR_DIMS(x)[0] - 1)
#define CreateStateDatums(dim) palloc(sizeof(Datum) * (dim + 1))
PG_MODULE_MAGIC;
/*
@@ -82,6 +85,20 @@ CheckElement(float value)
errmsg("infinite value not allowed in vector")));
}
/*
* Check state array
*/
static float8 *
CheckStateArray(ArrayType *statearray, const char *caller)
{
if (ARR_NDIM(statearray) != 1 ||
ARR_DIMS(statearray)[0] < 1 ||
ARR_HASNULL(statearray) ||
ARR_ELEMTYPE(statearray) != FLOAT8OID)
elog(ERROR, "%s: expected state array", caller);
return (float8 *) ARR_DATA_PTR(statearray);
}
/*
* Print vector - useful for debugging
*/
@@ -758,3 +775,167 @@ vector_cmp(PG_FUNCTION_ARGS)
PG_RETURN_INT32(vector_cmp_internal(a, b));
}
/*
* Accumulate vectors
*/
PGDLLEXPORT PG_FUNCTION_INFO_V1(vector_accum);
Datum
vector_accum(PG_FUNCTION_ARGS)
{
ArrayType *statearray = PG_GETARG_ARRAYTYPE_P(0);
Vector *newval = PG_GETARG_VECTOR_P(1);
float8 *statevalues;
int16 dim;
bool newarr;
float8 n;
Datum *statedatums;
float *x = newval->x;
ArrayType *result;
/* Check array before using */
statevalues = CheckStateArray(statearray, "vector_accum");
dim = STATE_DIMS(statearray);
newarr = dim == 0;
if (newarr)
dim = newval->dim;
else
CheckExpectedDim(dim, newval->dim);
n = statevalues[0] + 1.0;
statedatums = CreateStateDatums(dim);
statedatums[0] = Float8GetDatumFast(n);
if (newarr)
{
for (int i = 0; i < dim; i++)
statedatums[i + 1] = Float8GetDatumFast(x[i]);
}
else
{
for (int i = 0; i < dim; i++)
{
double v = statevalues[i + 1] + x[i];
if (isinf(v))
float_overflow_error();
statedatums[i + 1] = Float8GetDatumFast(v);
}
}
/* Use float8 array like float4_accum */
result = construct_array(statedatums, dim + 1,
FLOAT8OID,
sizeof(float8), FLOAT8PASSBYVAL, TYPALIGN_DOUBLE);
pfree(statedatums);
PG_RETURN_ARRAYTYPE_P(result);
}
/*
* Combine vectors
*/
PGDLLEXPORT PG_FUNCTION_INFO_V1(vector_combine);
Datum
vector_combine(PG_FUNCTION_ARGS)
{
ArrayType *statearray1 = PG_GETARG_ARRAYTYPE_P(0);
ArrayType *statearray2 = PG_GETARG_ARRAYTYPE_P(1);
float8 *statevalues1;
float8 *statevalues2;
float8 n;
float8 n1;
float8 n2;
int16 dim;
Datum *statedatums;
ArrayType *result;
/* Check arrays before using */
statevalues1 = CheckStateArray(statearray1, "vector_combine");
statevalues2 = CheckStateArray(statearray2, "vector_combine");
n1 = statevalues1[0];
n2 = statevalues2[0];
if (n1 == 0.0)
{
n = n2;
dim = STATE_DIMS(statearray2);
statedatums = CreateStateDatums(dim);
for (int i = 1; i <= dim; i++)
statedatums[i] = Float8GetDatumFast(statevalues2[i]);
}
else if (n2 == 0.0)
{
n = n1;
dim = STATE_DIMS(statearray1);
statedatums = CreateStateDatums(dim);
for (int i = 1; i <= dim; i++)
statedatums[i] = Float8GetDatumFast(statevalues1[i]);
}
else
{
n = n1 + n2;
dim = STATE_DIMS(statearray1);
CheckExpectedDim(dim, STATE_DIMS(statearray2));
statedatums = CreateStateDatums(dim);
for (int i = 1; i <= dim; i++)
{
double v = statevalues1[i] + statevalues2[i];
if (isinf(v))
float_overflow_error();
statedatums[i] = Float8GetDatumFast(v);
}
}
statedatums[0] = Float8GetDatumFast(n);
result = construct_array(statedatums, dim + 1,
FLOAT8OID,
sizeof(float8), FLOAT8PASSBYVAL, TYPALIGN_DOUBLE);
pfree(statedatums);
PG_RETURN_ARRAYTYPE_P(result);
}
/*
* Average vectors
*/
PGDLLEXPORT PG_FUNCTION_INFO_V1(vector_avg);
Datum
vector_avg(PG_FUNCTION_ARGS)
{
ArrayType *statearray = PG_GETARG_ARRAYTYPE_P(0);
float8 *statevalues;
float8 n;
uint16 dim;
Vector *result;
float v;
/* Check array before using */
statevalues = CheckStateArray(statearray, "vector_avg");
n = statevalues[0];
/* SQL defines AVG of no values to be NULL */
if (n == 0.0)
PG_RETURN_NULL();
/* Create vector */
dim = STATE_DIMS(statearray);
result = InitVector(dim);
for (int i = 0; i < dim; i++)
{
v = statevalues[i + 1] / n;
CheckElement(v);
result->x[i] = v;
}
PG_RETURN_POINTER(result);
}