Added avg aggregate for vector - closes #51

This commit is contained in:
Andrew Kane
2022-12-30 17:22:25 -08:00
parent b400ac0f36
commit e09f93cba7
7 changed files with 272 additions and 0 deletions

View File

@@ -3,6 +3,7 @@
- Changed text representation for vector elements to match `real`
- Changed storage for vector from `plain` to `extended` for new installations
- Improved accuracy of text parsing for certain inputs
- Added `avg` aggregate for vector
- Added experimental support for Windows
- Dropped support for Postgres 10

View File

@@ -3,3 +3,21 @@
-- requires Postgres 13+
-- ALTER TYPE vector SET (STORAGE = extended);
CREATE FUNCTION vector_accum(double precision[], vector) RETURNS double precision[]
AS 'MODULE_PATHNAME' LANGUAGE C IMMUTABLE STRICT PARALLEL SAFE;
CREATE FUNCTION vector_avg(double precision[]) RETURNS vector
AS 'MODULE_PATHNAME' LANGUAGE C IMMUTABLE STRICT PARALLEL SAFE;
CREATE FUNCTION vector_combine(double precision[], double precision[]) RETURNS double precision[]
AS 'MODULE_PATHNAME' LANGUAGE C IMMUTABLE STRICT PARALLEL SAFE;
CREATE AGGREGATE avg(vector) (
SFUNC = vector_accum,
STYPE = double precision[],
FINALFUNC = vector_avg,
COMBINEFUNC = vector_combine,
INITCOND = '{0}',
PARALLEL = SAFE
);

View File

@@ -84,6 +84,26 @@ CREATE FUNCTION vector_negative_inner_product(vector, vector) RETURNS float8
CREATE FUNCTION vector_spherical_distance(vector, vector) RETURNS float8
AS 'MODULE_PATHNAME' LANGUAGE C IMMUTABLE STRICT PARALLEL SAFE;
CREATE FUNCTION vector_accum(double precision[], vector) RETURNS double precision[]
AS 'MODULE_PATHNAME' LANGUAGE C IMMUTABLE STRICT PARALLEL SAFE;
CREATE FUNCTION vector_avg(double precision[]) RETURNS vector
AS 'MODULE_PATHNAME' LANGUAGE C IMMUTABLE STRICT PARALLEL SAFE;
CREATE FUNCTION vector_combine(double precision[], double precision[]) RETURNS double precision[]
AS 'MODULE_PATHNAME' LANGUAGE C IMMUTABLE STRICT PARALLEL SAFE;
-- aggregates
CREATE AGGREGATE avg(vector) (
SFUNC = vector_accum,
STYPE = double precision[],
FINALFUNC = vector_avg,
COMBINEFUNC = vector_combine,
INITCOND = '{0}',
PARALLEL = SAFE
);
-- cast functions
CREATE FUNCTION vector(vector, integer, boolean) RETURNS vector

View File

@@ -23,6 +23,9 @@
#define TYPALIGN_INT 'i'
#endif
#define STATE_DIMS(x) (ARR_DIMS(x)[0] - 1)
#define CreateStateDatums(dim) palloc(sizeof(Datum) * (dim + 1))
PG_MODULE_MAGIC;
/*
@@ -82,6 +85,20 @@ CheckElement(float value)
errmsg("infinite value not allowed in vector")));
}
/*
* Check state array
*/
static float8 *
CheckStateArray(ArrayType *statearray, const char *caller)
{
if (ARR_NDIM(statearray) != 1 ||
ARR_DIMS(statearray)[0] < 1 ||
ARR_HASNULL(statearray) ||
ARR_ELEMTYPE(statearray) != FLOAT8OID)
elog(ERROR, "%s: expected state array", caller);
return (float8 *) ARR_DATA_PTR(statearray);
}
/*
* Print vector - useful for debugging
*/
@@ -758,3 +775,167 @@ vector_cmp(PG_FUNCTION_ARGS)
PG_RETURN_INT32(vector_cmp_internal(a, b));
}
/*
* Accumulate vectors
*/
PGDLLEXPORT PG_FUNCTION_INFO_V1(vector_accum);
Datum
vector_accum(PG_FUNCTION_ARGS)
{
ArrayType *statearray = PG_GETARG_ARRAYTYPE_P(0);
Vector *newval = PG_GETARG_VECTOR_P(1);
float8 *statevalues;
int16 dim;
bool newarr;
float8 n;
Datum *statedatums;
float *x = newval->x;
ArrayType *result;
/* Check array before using */
statevalues = CheckStateArray(statearray, "vector_accum");
dim = STATE_DIMS(statearray);
newarr = dim == 0;
if (newarr)
dim = newval->dim;
else
CheckExpectedDim(dim, newval->dim);
n = statevalues[0] + 1.0;
statedatums = CreateStateDatums(dim);
statedatums[0] = Float8GetDatumFast(n);
if (newarr)
{
for (int i = 0; i < dim; i++)
statedatums[i + 1] = Float8GetDatumFast(x[i]);
}
else
{
for (int i = 0; i < dim; i++)
{
double v = statevalues[i + 1] + x[i];
if (isinf(v))
float_overflow_error();
statedatums[i + 1] = Float8GetDatumFast(v);
}
}
/* Use float8 array like float4_accum */
result = construct_array(statedatums, dim + 1,
FLOAT8OID,
sizeof(float8), FLOAT8PASSBYVAL, TYPALIGN_DOUBLE);
pfree(statedatums);
PG_RETURN_ARRAYTYPE_P(result);
}
/*
* Combine vectors
*/
PGDLLEXPORT PG_FUNCTION_INFO_V1(vector_combine);
Datum
vector_combine(PG_FUNCTION_ARGS)
{
ArrayType *statearray1 = PG_GETARG_ARRAYTYPE_P(0);
ArrayType *statearray2 = PG_GETARG_ARRAYTYPE_P(1);
float8 *statevalues1;
float8 *statevalues2;
float8 n;
float8 n1;
float8 n2;
int16 dim;
Datum *statedatums;
ArrayType *result;
/* Check arrays before using */
statevalues1 = CheckStateArray(statearray1, "vector_combine");
statevalues2 = CheckStateArray(statearray2, "vector_combine");
n1 = statevalues1[0];
n2 = statevalues2[0];
if (n1 == 0.0)
{
n = n2;
dim = STATE_DIMS(statearray2);
statedatums = CreateStateDatums(dim);
for (int i = 1; i <= dim; i++)
statedatums[i] = Float8GetDatumFast(statevalues2[i]);
}
else if (n2 == 0.0)
{
n = n1;
dim = STATE_DIMS(statearray1);
statedatums = CreateStateDatums(dim);
for (int i = 1; i <= dim; i++)
statedatums[i] = Float8GetDatumFast(statevalues1[i]);
}
else
{
n = n1 + n2;
dim = STATE_DIMS(statearray1);
CheckExpectedDim(dim, STATE_DIMS(statearray2));
statedatums = CreateStateDatums(dim);
for (int i = 1; i <= dim; i++)
{
double v = statevalues1[i] + statevalues2[i];
if (isinf(v))
float_overflow_error();
statedatums[i] = Float8GetDatumFast(v);
}
}
statedatums[0] = Float8GetDatumFast(n);
result = construct_array(statedatums, dim + 1,
FLOAT8OID,
sizeof(float8), FLOAT8PASSBYVAL, TYPALIGN_DOUBLE);
pfree(statedatums);
PG_RETURN_ARRAYTYPE_P(result);
}
/*
* Average vectors
*/
PGDLLEXPORT PG_FUNCTION_INFO_V1(vector_avg);
Datum
vector_avg(PG_FUNCTION_ARGS)
{
ArrayType *statearray = PG_GETARG_ARRAYTYPE_P(0);
float8 *statevalues;
float8 n;
uint16 dim;
Vector *result;
float v;
/* Check array before using */
statevalues = CheckStateArray(statearray, "vector_avg");
n = statevalues[0];
/* SQL defines AVG of no values to be NULL */
if (n == 0.0)
PG_RETURN_NULL();
/* Create vector */
dim = STATE_DIMS(statearray);
result = InitVector(dim);
for (int i = 0; i < dim; i++)
{
v = statevalues[i + 1] / n;
CheckElement(v);
result->x[i] = v;
}
PG_RETURN_POINTER(result);
}

View File

@@ -52,3 +52,23 @@ SELECT cosine_distance('[1,2]', '[0,0]');
SELECT cosine_distance('[1,2]', '[3]');
ERROR: different vector dimensions 2 and 1
SELECT avg(v) FROM unnest(ARRAY['[1,2,3]'::vector, '[3,5,7]']) v;
avg
-----------
[2,3.5,5]
(1 row)
SELECT avg(v) FROM unnest(ARRAY['[1,2,3]'::vector, '[3,5,7]', NULL]) v;
avg
-----------
[2,3.5,5]
(1 row)
SELECT avg(v) FROM unnest(ARRAY[]::vector[]) v;
avg
-----
(1 row)
SELECT avg(v) FROM unnest(ARRAY['[1,2]'::vector, '[3]']) v;
ERROR: expected 2 dimensions, not 1

View File

@@ -13,3 +13,8 @@ SELECT inner_product('[1,2]', '[3]');
SELECT round(cosine_distance('[1,2]', '[2,4]')::numeric, 5);
SELECT cosine_distance('[1,2]', '[0,0]');
SELECT cosine_distance('[1,2]', '[3]');
SELECT avg(v) FROM unnest(ARRAY['[1,2,3]'::vector, '[3,5,7]']) v;
SELECT avg(v) FROM unnest(ARRAY['[1,2,3]'::vector, '[3,5,7]', NULL]) v;
SELECT avg(v) FROM unnest(ARRAY[]::vector[]) v;
SELECT avg(v) FROM unnest(ARRAY['[1,2]'::vector, '[3]']) v;

27
test/t/008_avg.pl Normal file
View File

@@ -0,0 +1,27 @@
use strict;
use warnings;
use PostgresNode;
use TestLib;
use Test::More tests => 4;
# Initialize node
my $node = get_new_node('node');
$node->init;
$node->start;
# Create table
$node->safe_psql("postgres", "CREATE EXTENSION vector;");
$node->safe_psql("postgres", "CREATE TABLE tst (v vector(3));");
$node->safe_psql("postgres",
"INSERT INTO tst SELECT ARRAY[1.01 + random(), 2.01 + random(), 3.01 + random()] FROM generate_series(1, 1000000) i;"
);
# Test avg
my $avg = $node->safe_psql("postgres", "SELECT AVG(v) FROM tst;");
like($avg, qr/\[1\.5/);
like($avg, qr/,2\.5/);
like($avg, qr/,3\.5/);
# Test explain
my $explain = $node->safe_psql("postgres", "EXPLAIN SELECT AVG(v) FROM tst;");
like($explain, qr/Partial Aggregate/);