mirror of
https://github.com/pgvector/pgvector.git
synced 2026-06-06 05:51:21 +08:00
Added avg aggregate for vector - closes #51
This commit is contained in:
@@ -3,6 +3,7 @@
|
||||
- Changed text representation for vector elements to match `real`
|
||||
- Changed storage for vector from `plain` to `extended` for new installations
|
||||
- Improved accuracy of text parsing for certain inputs
|
||||
- Added `avg` aggregate for vector
|
||||
- Added experimental support for Windows
|
||||
- Dropped support for Postgres 10
|
||||
|
||||
|
||||
@@ -3,3 +3,21 @@
|
||||
|
||||
-- requires Postgres 13+
|
||||
-- ALTER TYPE vector SET (STORAGE = extended);
|
||||
|
||||
CREATE FUNCTION vector_accum(double precision[], vector) RETURNS double precision[]
|
||||
AS 'MODULE_PATHNAME' LANGUAGE C IMMUTABLE STRICT PARALLEL SAFE;
|
||||
|
||||
CREATE FUNCTION vector_avg(double precision[]) RETURNS vector
|
||||
AS 'MODULE_PATHNAME' LANGUAGE C IMMUTABLE STRICT PARALLEL SAFE;
|
||||
|
||||
CREATE FUNCTION vector_combine(double precision[], double precision[]) RETURNS double precision[]
|
||||
AS 'MODULE_PATHNAME' LANGUAGE C IMMUTABLE STRICT PARALLEL SAFE;
|
||||
|
||||
CREATE AGGREGATE avg(vector) (
|
||||
SFUNC = vector_accum,
|
||||
STYPE = double precision[],
|
||||
FINALFUNC = vector_avg,
|
||||
COMBINEFUNC = vector_combine,
|
||||
INITCOND = '{0}',
|
||||
PARALLEL = SAFE
|
||||
);
|
||||
|
||||
@@ -84,6 +84,26 @@ CREATE FUNCTION vector_negative_inner_product(vector, vector) RETURNS float8
|
||||
CREATE FUNCTION vector_spherical_distance(vector, vector) RETURNS float8
|
||||
AS 'MODULE_PATHNAME' LANGUAGE C IMMUTABLE STRICT PARALLEL SAFE;
|
||||
|
||||
CREATE FUNCTION vector_accum(double precision[], vector) RETURNS double precision[]
|
||||
AS 'MODULE_PATHNAME' LANGUAGE C IMMUTABLE STRICT PARALLEL SAFE;
|
||||
|
||||
CREATE FUNCTION vector_avg(double precision[]) RETURNS vector
|
||||
AS 'MODULE_PATHNAME' LANGUAGE C IMMUTABLE STRICT PARALLEL SAFE;
|
||||
|
||||
CREATE FUNCTION vector_combine(double precision[], double precision[]) RETURNS double precision[]
|
||||
AS 'MODULE_PATHNAME' LANGUAGE C IMMUTABLE STRICT PARALLEL SAFE;
|
||||
|
||||
-- aggregates
|
||||
|
||||
CREATE AGGREGATE avg(vector) (
|
||||
SFUNC = vector_accum,
|
||||
STYPE = double precision[],
|
||||
FINALFUNC = vector_avg,
|
||||
COMBINEFUNC = vector_combine,
|
||||
INITCOND = '{0}',
|
||||
PARALLEL = SAFE
|
||||
);
|
||||
|
||||
-- cast functions
|
||||
|
||||
CREATE FUNCTION vector(vector, integer, boolean) RETURNS vector
|
||||
|
||||
181
src/vector.c
181
src/vector.c
@@ -23,6 +23,9 @@
|
||||
#define TYPALIGN_INT 'i'
|
||||
#endif
|
||||
|
||||
#define STATE_DIMS(x) (ARR_DIMS(x)[0] - 1)
|
||||
#define CreateStateDatums(dim) palloc(sizeof(Datum) * (dim + 1))
|
||||
|
||||
PG_MODULE_MAGIC;
|
||||
|
||||
/*
|
||||
@@ -82,6 +85,20 @@ CheckElement(float value)
|
||||
errmsg("infinite value not allowed in vector")));
|
||||
}
|
||||
|
||||
/*
|
||||
* Check state array
|
||||
*/
|
||||
static float8 *
|
||||
CheckStateArray(ArrayType *statearray, const char *caller)
|
||||
{
|
||||
if (ARR_NDIM(statearray) != 1 ||
|
||||
ARR_DIMS(statearray)[0] < 1 ||
|
||||
ARR_HASNULL(statearray) ||
|
||||
ARR_ELEMTYPE(statearray) != FLOAT8OID)
|
||||
elog(ERROR, "%s: expected state array", caller);
|
||||
return (float8 *) ARR_DATA_PTR(statearray);
|
||||
}
|
||||
|
||||
/*
|
||||
* Print vector - useful for debugging
|
||||
*/
|
||||
@@ -758,3 +775,167 @@ vector_cmp(PG_FUNCTION_ARGS)
|
||||
|
||||
PG_RETURN_INT32(vector_cmp_internal(a, b));
|
||||
}
|
||||
|
||||
/*
|
||||
* Accumulate vectors
|
||||
*/
|
||||
PGDLLEXPORT PG_FUNCTION_INFO_V1(vector_accum);
|
||||
Datum
|
||||
vector_accum(PG_FUNCTION_ARGS)
|
||||
{
|
||||
ArrayType *statearray = PG_GETARG_ARRAYTYPE_P(0);
|
||||
Vector *newval = PG_GETARG_VECTOR_P(1);
|
||||
float8 *statevalues;
|
||||
int16 dim;
|
||||
bool newarr;
|
||||
float8 n;
|
||||
Datum *statedatums;
|
||||
float *x = newval->x;
|
||||
ArrayType *result;
|
||||
|
||||
/* Check array before using */
|
||||
statevalues = CheckStateArray(statearray, "vector_accum");
|
||||
dim = STATE_DIMS(statearray);
|
||||
newarr = dim == 0;
|
||||
|
||||
if (newarr)
|
||||
dim = newval->dim;
|
||||
else
|
||||
CheckExpectedDim(dim, newval->dim);
|
||||
|
||||
n = statevalues[0] + 1.0;
|
||||
|
||||
statedatums = CreateStateDatums(dim);
|
||||
statedatums[0] = Float8GetDatumFast(n);
|
||||
|
||||
if (newarr)
|
||||
{
|
||||
for (int i = 0; i < dim; i++)
|
||||
statedatums[i + 1] = Float8GetDatumFast(x[i]);
|
||||
}
|
||||
else
|
||||
{
|
||||
for (int i = 0; i < dim; i++)
|
||||
{
|
||||
double v = statevalues[i + 1] + x[i];
|
||||
|
||||
if (isinf(v))
|
||||
float_overflow_error();
|
||||
|
||||
statedatums[i + 1] = Float8GetDatumFast(v);
|
||||
}
|
||||
}
|
||||
|
||||
/* Use float8 array like float4_accum */
|
||||
result = construct_array(statedatums, dim + 1,
|
||||
FLOAT8OID,
|
||||
sizeof(float8), FLOAT8PASSBYVAL, TYPALIGN_DOUBLE);
|
||||
|
||||
pfree(statedatums);
|
||||
|
||||
PG_RETURN_ARRAYTYPE_P(result);
|
||||
}
|
||||
|
||||
/*
|
||||
* Combine vectors
|
||||
*/
|
||||
PGDLLEXPORT PG_FUNCTION_INFO_V1(vector_combine);
|
||||
Datum
|
||||
vector_combine(PG_FUNCTION_ARGS)
|
||||
{
|
||||
ArrayType *statearray1 = PG_GETARG_ARRAYTYPE_P(0);
|
||||
ArrayType *statearray2 = PG_GETARG_ARRAYTYPE_P(1);
|
||||
float8 *statevalues1;
|
||||
float8 *statevalues2;
|
||||
float8 n;
|
||||
float8 n1;
|
||||
float8 n2;
|
||||
int16 dim;
|
||||
Datum *statedatums;
|
||||
ArrayType *result;
|
||||
|
||||
/* Check arrays before using */
|
||||
statevalues1 = CheckStateArray(statearray1, "vector_combine");
|
||||
statevalues2 = CheckStateArray(statearray2, "vector_combine");
|
||||
|
||||
n1 = statevalues1[0];
|
||||
n2 = statevalues2[0];
|
||||
|
||||
if (n1 == 0.0)
|
||||
{
|
||||
n = n2;
|
||||
dim = STATE_DIMS(statearray2);
|
||||
statedatums = CreateStateDatums(dim);
|
||||
for (int i = 1; i <= dim; i++)
|
||||
statedatums[i] = Float8GetDatumFast(statevalues2[i]);
|
||||
}
|
||||
else if (n2 == 0.0)
|
||||
{
|
||||
n = n1;
|
||||
dim = STATE_DIMS(statearray1);
|
||||
statedatums = CreateStateDatums(dim);
|
||||
for (int i = 1; i <= dim; i++)
|
||||
statedatums[i] = Float8GetDatumFast(statevalues1[i]);
|
||||
}
|
||||
else
|
||||
{
|
||||
n = n1 + n2;
|
||||
dim = STATE_DIMS(statearray1);
|
||||
CheckExpectedDim(dim, STATE_DIMS(statearray2));
|
||||
statedatums = CreateStateDatums(dim);
|
||||
for (int i = 1; i <= dim; i++)
|
||||
{
|
||||
double v = statevalues1[i] + statevalues2[i];
|
||||
|
||||
if (isinf(v))
|
||||
float_overflow_error();
|
||||
|
||||
statedatums[i] = Float8GetDatumFast(v);
|
||||
}
|
||||
}
|
||||
|
||||
statedatums[0] = Float8GetDatumFast(n);
|
||||
|
||||
result = construct_array(statedatums, dim + 1,
|
||||
FLOAT8OID,
|
||||
sizeof(float8), FLOAT8PASSBYVAL, TYPALIGN_DOUBLE);
|
||||
|
||||
pfree(statedatums);
|
||||
|
||||
PG_RETURN_ARRAYTYPE_P(result);
|
||||
}
|
||||
|
||||
/*
|
||||
* Average vectors
|
||||
*/
|
||||
PGDLLEXPORT PG_FUNCTION_INFO_V1(vector_avg);
|
||||
Datum
|
||||
vector_avg(PG_FUNCTION_ARGS)
|
||||
{
|
||||
ArrayType *statearray = PG_GETARG_ARRAYTYPE_P(0);
|
||||
float8 *statevalues;
|
||||
float8 n;
|
||||
uint16 dim;
|
||||
Vector *result;
|
||||
float v;
|
||||
|
||||
/* Check array before using */
|
||||
statevalues = CheckStateArray(statearray, "vector_avg");
|
||||
n = statevalues[0];
|
||||
|
||||
/* SQL defines AVG of no values to be NULL */
|
||||
if (n == 0.0)
|
||||
PG_RETURN_NULL();
|
||||
|
||||
/* Create vector */
|
||||
dim = STATE_DIMS(statearray);
|
||||
result = InitVector(dim);
|
||||
for (int i = 0; i < dim; i++)
|
||||
{
|
||||
v = statevalues[i + 1] / n;
|
||||
CheckElement(v);
|
||||
result->x[i] = v;
|
||||
}
|
||||
|
||||
PG_RETURN_POINTER(result);
|
||||
}
|
||||
|
||||
@@ -52,3 +52,23 @@ SELECT cosine_distance('[1,2]', '[0,0]');
|
||||
|
||||
SELECT cosine_distance('[1,2]', '[3]');
|
||||
ERROR: different vector dimensions 2 and 1
|
||||
SELECT avg(v) FROM unnest(ARRAY['[1,2,3]'::vector, '[3,5,7]']) v;
|
||||
avg
|
||||
-----------
|
||||
[2,3.5,5]
|
||||
(1 row)
|
||||
|
||||
SELECT avg(v) FROM unnest(ARRAY['[1,2,3]'::vector, '[3,5,7]', NULL]) v;
|
||||
avg
|
||||
-----------
|
||||
[2,3.5,5]
|
||||
(1 row)
|
||||
|
||||
SELECT avg(v) FROM unnest(ARRAY[]::vector[]) v;
|
||||
avg
|
||||
-----
|
||||
|
||||
(1 row)
|
||||
|
||||
SELECT avg(v) FROM unnest(ARRAY['[1,2]'::vector, '[3]']) v;
|
||||
ERROR: expected 2 dimensions, not 1
|
||||
|
||||
@@ -13,3 +13,8 @@ SELECT inner_product('[1,2]', '[3]');
|
||||
SELECT round(cosine_distance('[1,2]', '[2,4]')::numeric, 5);
|
||||
SELECT cosine_distance('[1,2]', '[0,0]');
|
||||
SELECT cosine_distance('[1,2]', '[3]');
|
||||
|
||||
SELECT avg(v) FROM unnest(ARRAY['[1,2,3]'::vector, '[3,5,7]']) v;
|
||||
SELECT avg(v) FROM unnest(ARRAY['[1,2,3]'::vector, '[3,5,7]', NULL]) v;
|
||||
SELECT avg(v) FROM unnest(ARRAY[]::vector[]) v;
|
||||
SELECT avg(v) FROM unnest(ARRAY['[1,2]'::vector, '[3]']) v;
|
||||
|
||||
27
test/t/008_avg.pl
Normal file
27
test/t/008_avg.pl
Normal file
@@ -0,0 +1,27 @@
|
||||
use strict;
|
||||
use warnings;
|
||||
use PostgresNode;
|
||||
use TestLib;
|
||||
use Test::More tests => 4;
|
||||
|
||||
# Initialize node
|
||||
my $node = get_new_node('node');
|
||||
$node->init;
|
||||
$node->start;
|
||||
|
||||
# Create table
|
||||
$node->safe_psql("postgres", "CREATE EXTENSION vector;");
|
||||
$node->safe_psql("postgres", "CREATE TABLE tst (v vector(3));");
|
||||
$node->safe_psql("postgres",
|
||||
"INSERT INTO tst SELECT ARRAY[1.01 + random(), 2.01 + random(), 3.01 + random()] FROM generate_series(1, 1000000) i;"
|
||||
);
|
||||
|
||||
# Test avg
|
||||
my $avg = $node->safe_psql("postgres", "SELECT AVG(v) FROM tst;");
|
||||
like($avg, qr/\[1\.5/);
|
||||
like($avg, qr/,2\.5/);
|
||||
like($avg, qr/,3\.5/);
|
||||
|
||||
# Test explain
|
||||
my $explain = $node->safe_psql("postgres", "EXPLAIN SELECT AVG(v) FROM tst;");
|
||||
like($explain, qr/Partial Aggregate/);
|
||||
Reference in New Issue
Block a user