/*
 * Mirror of https://github.com/pgvector/pgvector.git
 * (synced 2026-06-06 05:51:21 +08:00; 1262 lines, 26 KiB, C)
 */
#include "postgres.h"
|
|
|
|
#include <limits.h>
|
|
#include <math.h>
|
|
|
|
#include "catalog/pg_type.h"
|
|
#include "common/shortest_dec.h"
|
|
#include "fmgr.h"
|
|
#include "halfutils.h"
|
|
#include "halfvec.h"
|
|
#include "lib/stringinfo.h"
|
|
#include "libpq/pqformat.h"
|
|
#include "sparsevec.h"
|
|
#include "utils/array.h"
|
|
#include "utils/builtins.h"
|
|
#include "utils/float.h"
|
|
#include "utils/fmgrprotos.h"
|
|
#include "utils/lsyscache.h"
|
|
#include "vector.h"
|
|
|
|
#if PG_VERSION_NUM >= 160000
|
|
#include "varatt.h"
|
|
#endif
|
|
|
|
typedef struct SparseInputElement
|
|
{
|
|
int32 index;
|
|
float value;
|
|
} SparseInputElement;
|
|
|
|
/*
|
|
* Ensure same dimensions
|
|
*/
|
|
static inline void
|
|
CheckDims(SparseVector * a, SparseVector * b)
|
|
{
|
|
if (a->dim != b->dim)
|
|
ereport(ERROR,
|
|
(errcode(ERRCODE_DATA_EXCEPTION),
|
|
errmsg("different sparsevec dimensions %d and %d", a->dim, b->dim)));
|
|
}
|
|
|
|
/*
|
|
* Ensure expected dimensions
|
|
*/
|
|
static inline void
|
|
CheckExpectedDim(int32 typmod, int dim)
|
|
{
|
|
if (typmod != -1 && typmod != dim)
|
|
ereport(ERROR,
|
|
(errcode(ERRCODE_DATA_EXCEPTION),
|
|
errmsg("expected %d dimensions, not %d", typmod, dim)));
|
|
}
|
|
|
|
/*
|
|
* Ensure valid dimensions
|
|
*/
|
|
static inline void
|
|
CheckDim(int dim)
|
|
{
|
|
if (dim < 1)
|
|
ereport(ERROR,
|
|
(errcode(ERRCODE_DATA_EXCEPTION),
|
|
errmsg("sparsevec must have at least 1 dimension")));
|
|
|
|
if (dim > SPARSEVEC_MAX_DIM)
|
|
ereport(ERROR,
|
|
(errcode(ERRCODE_PROGRAM_LIMIT_EXCEEDED),
|
|
errmsg("sparsevec cannot have more than %d dimensions", SPARSEVEC_MAX_DIM)));
|
|
}
|
|
|
|
/*
|
|
* Ensure valid nnz
|
|
*/
|
|
static inline void
|
|
CheckNnz(int nnz, int dim)
|
|
{
|
|
if (nnz < 0)
|
|
ereport(ERROR,
|
|
(errcode(ERRCODE_DATA_EXCEPTION),
|
|
errmsg("sparsevec cannot have negative number of elements")));
|
|
|
|
if (nnz > SPARSEVEC_MAX_NNZ)
|
|
ereport(ERROR,
|
|
(errcode(ERRCODE_PROGRAM_LIMIT_EXCEEDED),
|
|
errmsg("sparsevec cannot have more than %d non-zero elements", SPARSEVEC_MAX_NNZ)));
|
|
|
|
if (nnz > dim)
|
|
ereport(ERROR,
|
|
(errcode(ERRCODE_PROGRAM_LIMIT_EXCEEDED),
|
|
errmsg("sparsevec cannot have more elements than dimensions")));
|
|
}
|
|
|
|
/*
|
|
* Ensure valid index
|
|
*/
|
|
static inline void
|
|
CheckIndex(int32 *indices, int i, int dim)
|
|
{
|
|
int32 index = indices[i];
|
|
|
|
if (index < 0 || index >= dim)
|
|
{
|
|
ereport(ERROR,
|
|
(errcode(ERRCODE_DATA_EXCEPTION),
|
|
errmsg("sparsevec index out of bounds")));
|
|
}
|
|
|
|
if (i > 0)
|
|
{
|
|
if (index < indices[i - 1])
|
|
ereport(ERROR,
|
|
(errcode(ERRCODE_DATA_EXCEPTION),
|
|
errmsg("sparsevec indices must be in ascending order")));
|
|
|
|
if (index == indices[i - 1])
|
|
ereport(ERROR,
|
|
(errcode(ERRCODE_DATA_EXCEPTION),
|
|
errmsg("sparsevec indices must not contain duplicates")));
|
|
}
|
|
}
|
|
|
|
/*
|
|
* Ensure finite element
|
|
*/
|
|
static inline void
|
|
CheckElement(float value)
|
|
{
|
|
if (isnan(value))
|
|
ereport(ERROR,
|
|
(errcode(ERRCODE_DATA_EXCEPTION),
|
|
errmsg("NaN not allowed in sparsevec")));
|
|
|
|
if (isinf(value))
|
|
ereport(ERROR,
|
|
(errcode(ERRCODE_DATA_EXCEPTION),
|
|
errmsg("infinite value not allowed in sparsevec")));
|
|
}
|
|
|
|
/*
|
|
* Allocate and initialize a new sparse vector
|
|
*/
|
|
SparseVector *
|
|
InitSparseVector(int dim, int nnz)
|
|
{
|
|
SparseVector *result;
|
|
int size;
|
|
|
|
size = SPARSEVEC_SIZE(nnz);
|
|
result = (SparseVector *) palloc0(size);
|
|
SET_VARSIZE(result, size);
|
|
result->dim = dim;
|
|
result->nnz = nnz;
|
|
|
|
return result;
|
|
}
|
|
|
|
/*
 * Report whether ch is one of the six whitespace characters the parser
 * skips: space, tab, newline, carriage return, vertical tab or form feed.
 * (A local copy is needed because array_isspace() is static.)
 */
static inline bool
sparsevec_isspace(char ch)
{
	switch (ch)
	{
		case ' ':
		case '\t':
		case '\n':
		case '\r':
		case '\v':
		case '\f':
			return true;
		default:
			return false;
	}
}
|
|
|
|
/*
|
|
* Compare indices
|
|
*/
|
|
static int
|
|
CompareIndices(const void *a, const void *b)
|
|
{
|
|
if (((SparseInputElement *) a)->index < ((SparseInputElement *) b)->index)
|
|
return -1;
|
|
|
|
if (((SparseInputElement *) a)->index > ((SparseInputElement *) b)->index)
|
|
return 1;
|
|
|
|
return 0;
|
|
}
|
|
|
|
/*
|
|
* Convert textual representation to internal representation
|
|
*/
|
|
FUNCTION_PREFIX PG_FUNCTION_INFO_V1(sparsevec_in);
|
|
Datum
|
|
sparsevec_in(PG_FUNCTION_ARGS)
|
|
{
|
|
char *lit = PG_GETARG_CSTRING(0);
|
|
int32 typmod = PG_GETARG_INT32(2);
|
|
long dim;
|
|
char *pt = lit;
|
|
char *stringEnd;
|
|
SparseVector *result;
|
|
float *rvalues;
|
|
SparseInputElement *elements;
|
|
int maxNnz;
|
|
int nnz = 0;
|
|
|
|
maxNnz = 1;
|
|
while (*pt != '\0')
|
|
{
|
|
if (*pt == ',')
|
|
maxNnz++;
|
|
|
|
pt++;
|
|
}
|
|
|
|
if (maxNnz > SPARSEVEC_MAX_NNZ)
|
|
ereport(ERROR,
|
|
(errcode(ERRCODE_PROGRAM_LIMIT_EXCEEDED),
|
|
errmsg("sparsevec cannot have more than %d non-zero elements", SPARSEVEC_MAX_NNZ)));
|
|
|
|
elements = palloc(maxNnz * sizeof(SparseInputElement));
|
|
|
|
pt = lit;
|
|
|
|
while (sparsevec_isspace(*pt))
|
|
pt++;
|
|
|
|
if (*pt != '{')
|
|
ereport(ERROR,
|
|
(errcode(ERRCODE_INVALID_TEXT_REPRESENTATION),
|
|
errmsg("invalid input syntax for type sparsevec: \"%s\"", lit),
|
|
errdetail("Vector contents must start with \"{\".")));
|
|
|
|
pt++;
|
|
|
|
while (sparsevec_isspace(*pt))
|
|
pt++;
|
|
|
|
if (*pt == '}')
|
|
pt++;
|
|
else
|
|
{
|
|
for (;;)
|
|
{
|
|
long index;
|
|
float value;
|
|
|
|
if (nnz == maxNnz)
|
|
ereport(ERROR,
|
|
(errcode(ERRCODE_INVALID_TEXT_REPRESENTATION),
|
|
errmsg("ran out of buffer: \"%s\"", lit)));
|
|
|
|
while (sparsevec_isspace(*pt))
|
|
pt++;
|
|
|
|
/* Check for empty string like float4in */
|
|
if (*pt == '\0')
|
|
ereport(ERROR,
|
|
(errcode(ERRCODE_INVALID_TEXT_REPRESENTATION),
|
|
errmsg("invalid input syntax for type sparsevec: \"%s\"", lit)));
|
|
|
|
/* Use similar logic as int2vectorin */
|
|
index = strtol(pt, &stringEnd, 10);
|
|
|
|
if (stringEnd == pt)
|
|
ereport(ERROR,
|
|
(errcode(ERRCODE_INVALID_TEXT_REPRESENTATION),
|
|
errmsg("invalid input syntax for type sparsevec: \"%s\"", lit)));
|
|
|
|
/* Keep in int range for correct error message later */
|
|
if (index > INT_MAX)
|
|
index = INT_MAX;
|
|
else if (index < INT_MIN + 1)
|
|
index = INT_MIN + 1;
|
|
|
|
pt = stringEnd;
|
|
|
|
while (sparsevec_isspace(*pt))
|
|
pt++;
|
|
|
|
if (*pt != ':')
|
|
ereport(ERROR,
|
|
(errcode(ERRCODE_INVALID_TEXT_REPRESENTATION),
|
|
errmsg("invalid input syntax for type sparsevec: \"%s\"", lit)));
|
|
|
|
pt++;
|
|
|
|
while (sparsevec_isspace(*pt))
|
|
pt++;
|
|
|
|
errno = 0;
|
|
|
|
/* Use strtof like float4in to avoid a double-rounding problem */
|
|
/* Postgres sets LC_NUMERIC to C on startup */
|
|
value = strtof(pt, &stringEnd);
|
|
|
|
if (stringEnd == pt)
|
|
ereport(ERROR,
|
|
(errcode(ERRCODE_INVALID_TEXT_REPRESENTATION),
|
|
errmsg("invalid input syntax for type sparsevec: \"%s\"", lit)));
|
|
|
|
/* Check for range error like float4in */
|
|
if (errno == ERANGE && (value == 0 || isinf(value)))
|
|
ereport(ERROR,
|
|
(errcode(ERRCODE_NUMERIC_VALUE_OUT_OF_RANGE),
|
|
errmsg("\"%s\" is out of range for type sparsevec", pnstrdup(pt, stringEnd - pt))));
|
|
|
|
CheckElement(value);
|
|
|
|
/* Do not store zero values */
|
|
if (value != 0)
|
|
{
|
|
/* Convert 1-based numbering (SQL) to 0-based (C) */
|
|
elements[nnz].index = index - 1;
|
|
elements[nnz].value = value;
|
|
nnz++;
|
|
}
|
|
|
|
pt = stringEnd;
|
|
|
|
while (sparsevec_isspace(*pt))
|
|
pt++;
|
|
|
|
if (*pt == ',')
|
|
pt++;
|
|
else if (*pt == '}')
|
|
{
|
|
pt++;
|
|
break;
|
|
}
|
|
else
|
|
ereport(ERROR,
|
|
(errcode(ERRCODE_INVALID_TEXT_REPRESENTATION),
|
|
errmsg("invalid input syntax for type sparsevec: \"%s\"", lit)));
|
|
}
|
|
}
|
|
|
|
while (sparsevec_isspace(*pt))
|
|
pt++;
|
|
|
|
if (*pt != '/')
|
|
ereport(ERROR,
|
|
(errcode(ERRCODE_INVALID_TEXT_REPRESENTATION),
|
|
errmsg("invalid input syntax for type sparsevec: \"%s\"", lit),
|
|
errdetail("Unexpected end of input.")));
|
|
|
|
pt++;
|
|
|
|
while (sparsevec_isspace(*pt))
|
|
pt++;
|
|
|
|
/* Use similar logic as int2vectorin */
|
|
dim = strtol(pt, &stringEnd, 10);
|
|
|
|
if (stringEnd == pt)
|
|
ereport(ERROR,
|
|
(errcode(ERRCODE_INVALID_TEXT_REPRESENTATION),
|
|
errmsg("invalid input syntax for type sparsevec: \"%s\"", lit)));
|
|
|
|
/* Keep in int range for correct error message later */
|
|
if (dim > INT_MAX)
|
|
dim = INT_MAX;
|
|
else if (dim < INT_MIN)
|
|
dim = INT_MIN;
|
|
|
|
pt = stringEnd;
|
|
|
|
/* Only whitespace is allowed after the closing brace */
|
|
while (sparsevec_isspace(*pt))
|
|
pt++;
|
|
|
|
if (*pt != '\0')
|
|
ereport(ERROR,
|
|
(errcode(ERRCODE_INVALID_TEXT_REPRESENTATION),
|
|
errmsg("invalid input syntax for type sparsevec: \"%s\"", lit),
|
|
errdetail("Junk after closing.")));
|
|
|
|
CheckDim(dim);
|
|
CheckExpectedDim(typmod, dim);
|
|
|
|
qsort(elements, nnz, sizeof(SparseInputElement), CompareIndices);
|
|
|
|
result = InitSparseVector(dim, nnz);
|
|
rvalues = SPARSEVEC_VALUES(result);
|
|
for (int i = 0; i < nnz; i++)
|
|
{
|
|
result->indices[i] = elements[i].index;
|
|
rvalues[i] = elements[i].value;
|
|
|
|
CheckIndex(result->indices, i, dim);
|
|
}
|
|
|
|
PG_RETURN_POINTER(result);
|
|
}
|
|
|
|
/*
 * Output helpers used when building the text form. Each one writes at ptr
 * and leaves ptr pointing just past the characters it produced. ptr is
 * expanded more than once, so pass a plain pointer variable with no side
 * effects.
 */
#define AppendChar(ptr, c) (*(ptr)++ = (c))
#define AppendFloat(ptr, f) ((ptr) += float_to_shortest_decimal_bufn((f), (ptr)))

#if PG_VERSION_NUM >= 140000
/* Here the return value of pg_ltoa is used as the number of bytes written */
#define AppendInt(ptr, i) ((ptr) += pg_ltoa((i), (ptr)))
#else
/* Here the return value is not used: advance to the terminating NUL instead */
#define AppendInt(ptr, i) \
	do { \
		pg_ltoa((i), (ptr)); \
		while (*(ptr) != '\0') \
			(ptr)++; \
	} while (0)
#endif
|
|
|
|
/*
|
|
* Convert internal representation to textual representation
|
|
*/
|
|
FUNCTION_PREFIX PG_FUNCTION_INFO_V1(sparsevec_out);
|
|
Datum
|
|
sparsevec_out(PG_FUNCTION_ARGS)
|
|
{
|
|
SparseVector *sparsevec = PG_GETARG_SPARSEVEC_P(0);
|
|
float *values = SPARSEVEC_VALUES(sparsevec);
|
|
char *buf;
|
|
char *ptr;
|
|
|
|
/*
|
|
* Need:
|
|
*
|
|
* nnz * 10 bytes for index (positive integer)
|
|
*
|
|
* nnz bytes for :
|
|
*
|
|
* nnz * (FLOAT_SHORTEST_DECIMAL_LEN - 1) bytes for
|
|
* float_to_shortest_decimal_bufn
|
|
*
|
|
* nnz - 1 bytes for ,
|
|
*
|
|
* 10 bytes for dimensions
|
|
*
|
|
* 4 bytes for {, }, /, and \0
|
|
*/
|
|
buf = (char *) palloc((11 + FLOAT_SHORTEST_DECIMAL_LEN) * sparsevec->nnz + 13);
|
|
ptr = buf;
|
|
|
|
AppendChar(ptr, '{');
|
|
|
|
for (int i = 0; i < sparsevec->nnz; i++)
|
|
{
|
|
if (i > 0)
|
|
AppendChar(ptr, ',');
|
|
|
|
/* Convert 0-based numbering (C) to 1-based (SQL) */
|
|
AppendInt(ptr, sparsevec->indices[i] + 1);
|
|
AppendChar(ptr, ':');
|
|
AppendFloat(ptr, values[i]);
|
|
}
|
|
|
|
AppendChar(ptr, '}');
|
|
AppendChar(ptr, '/');
|
|
AppendInt(ptr, sparsevec->dim);
|
|
*ptr = '\0';
|
|
|
|
PG_FREE_IF_COPY(sparsevec, 0);
|
|
PG_RETURN_CSTRING(buf);
|
|
}
|
|
|
|
/*
|
|
* Convert type modifier
|
|
*/
|
|
FUNCTION_PREFIX PG_FUNCTION_INFO_V1(sparsevec_typmod_in);
|
|
Datum
|
|
sparsevec_typmod_in(PG_FUNCTION_ARGS)
|
|
{
|
|
ArrayType *ta = PG_GETARG_ARRAYTYPE_P(0);
|
|
int32 *tl;
|
|
int n;
|
|
|
|
tl = ArrayGetIntegerTypmods(ta, &n);
|
|
|
|
if (n != 1)
|
|
ereport(ERROR,
|
|
(errcode(ERRCODE_INVALID_PARAMETER_VALUE),
|
|
errmsg("invalid type modifier")));
|
|
|
|
if (*tl < 1)
|
|
ereport(ERROR,
|
|
(errcode(ERRCODE_INVALID_PARAMETER_VALUE),
|
|
errmsg("dimensions for type sparsevec must be at least 1")));
|
|
|
|
if (*tl > SPARSEVEC_MAX_DIM)
|
|
ereport(ERROR,
|
|
(errcode(ERRCODE_INVALID_PARAMETER_VALUE),
|
|
errmsg("dimensions for type sparsevec cannot exceed %d", SPARSEVEC_MAX_DIM)));
|
|
|
|
PG_RETURN_INT32(*tl);
|
|
}
|
|
|
|
/*
|
|
* Convert external binary representation to internal representation
|
|
*/
|
|
FUNCTION_PREFIX PG_FUNCTION_INFO_V1(sparsevec_recv);
|
|
Datum
|
|
sparsevec_recv(PG_FUNCTION_ARGS)
|
|
{
|
|
StringInfo buf = (StringInfo) PG_GETARG_POINTER(0);
|
|
int32 typmod = PG_GETARG_INT32(2);
|
|
SparseVector *result;
|
|
int32 dim;
|
|
int32 nnz;
|
|
int32 unused;
|
|
float *values;
|
|
|
|
dim = pq_getmsgint(buf, sizeof(int32));
|
|
nnz = pq_getmsgint(buf, sizeof(int32));
|
|
unused = pq_getmsgint(buf, sizeof(int32));
|
|
|
|
CheckDim(dim);
|
|
CheckNnz(nnz, dim);
|
|
CheckExpectedDim(typmod, dim);
|
|
|
|
if (unused != 0)
|
|
ereport(ERROR,
|
|
(errcode(ERRCODE_DATA_EXCEPTION),
|
|
errmsg("expected unused to be 0, not %d", unused)));
|
|
|
|
result = InitSparseVector(dim, nnz);
|
|
values = SPARSEVEC_VALUES(result);
|
|
|
|
/* Binary representation uses zero-based numbering for indices */
|
|
for (int i = 0; i < nnz; i++)
|
|
{
|
|
result->indices[i] = pq_getmsgint(buf, sizeof(int32));
|
|
CheckIndex(result->indices, i, dim);
|
|
}
|
|
|
|
for (int i = 0; i < nnz; i++)
|
|
{
|
|
values[i] = pq_getmsgfloat4(buf);
|
|
CheckElement(values[i]);
|
|
|
|
if (values[i] == 0)
|
|
ereport(ERROR,
|
|
(errcode(ERRCODE_DATA_EXCEPTION),
|
|
errmsg("binary representation of sparsevec cannot contain zero values")));
|
|
}
|
|
|
|
PG_RETURN_POINTER(result);
|
|
}
|
|
|
|
/*
|
|
* Convert internal representation to the external binary representation
|
|
*/
|
|
FUNCTION_PREFIX PG_FUNCTION_INFO_V1(sparsevec_send);
|
|
Datum
|
|
sparsevec_send(PG_FUNCTION_ARGS)
|
|
{
|
|
SparseVector *svec = PG_GETARG_SPARSEVEC_P(0);
|
|
float *values = SPARSEVEC_VALUES(svec);
|
|
StringInfoData buf;
|
|
|
|
pq_begintypsend(&buf);
|
|
pq_sendint(&buf, svec->dim, sizeof(int32));
|
|
pq_sendint(&buf, svec->nnz, sizeof(int32));
|
|
pq_sendint(&buf, svec->unused, sizeof(int32));
|
|
|
|
/* Binary representation uses zero-based numbering for indices */
|
|
for (int i = 0; i < svec->nnz; i++)
|
|
pq_sendint(&buf, svec->indices[i], sizeof(int32));
|
|
|
|
for (int i = 0; i < svec->nnz; i++)
|
|
pq_sendfloat4(&buf, values[i]);
|
|
|
|
PG_RETURN_BYTEA_P(pq_endtypsend(&buf));
|
|
}
|
|
|
|
/*
|
|
* Convert sparse vector to sparse vector
|
|
* This is needed to check the type modifier
|
|
*/
|
|
FUNCTION_PREFIX PG_FUNCTION_INFO_V1(sparsevec);
|
|
Datum
|
|
sparsevec(PG_FUNCTION_ARGS)
|
|
{
|
|
SparseVector *svec = PG_GETARG_SPARSEVEC_P(0);
|
|
int32 typmod = PG_GETARG_INT32(1);
|
|
|
|
CheckExpectedDim(typmod, svec->dim);
|
|
|
|
PG_RETURN_POINTER(svec);
|
|
}
|
|
|
|
/*
|
|
* Convert dense vector to sparse vector
|
|
*/
|
|
FUNCTION_PREFIX PG_FUNCTION_INFO_V1(vector_to_sparsevec);
|
|
Datum
|
|
vector_to_sparsevec(PG_FUNCTION_ARGS)
|
|
{
|
|
Vector *vec = PG_GETARG_VECTOR_P(0);
|
|
int32 typmod = PG_GETARG_INT32(1);
|
|
SparseVector *result;
|
|
int dim = vec->dim;
|
|
int nnz = 0;
|
|
float *values;
|
|
int j = 0;
|
|
|
|
CheckDim(dim);
|
|
CheckExpectedDim(typmod, dim);
|
|
|
|
for (int i = 0; i < dim; i++)
|
|
{
|
|
if (vec->x[i] != 0)
|
|
nnz++;
|
|
}
|
|
|
|
result = InitSparseVector(dim, nnz);
|
|
values = SPARSEVEC_VALUES(result);
|
|
for (int i = 0; i < dim; i++)
|
|
{
|
|
if (vec->x[i] != 0)
|
|
{
|
|
/* Safety check */
|
|
if (j >= result->nnz)
|
|
elog(ERROR, "safety check failed");
|
|
|
|
result->indices[j] = i;
|
|
values[j] = vec->x[i];
|
|
j++;
|
|
}
|
|
}
|
|
|
|
PG_RETURN_POINTER(result);
|
|
}
|
|
|
|
/*
|
|
* Convert half vector to sparse vector
|
|
*/
|
|
FUNCTION_PREFIX PG_FUNCTION_INFO_V1(halfvec_to_sparsevec);
|
|
Datum
|
|
halfvec_to_sparsevec(PG_FUNCTION_ARGS)
|
|
{
|
|
HalfVector *vec = PG_GETARG_HALFVEC_P(0);
|
|
int32 typmod = PG_GETARG_INT32(1);
|
|
SparseVector *result;
|
|
int dim = vec->dim;
|
|
int nnz = 0;
|
|
float *values;
|
|
int j = 0;
|
|
|
|
CheckDim(dim);
|
|
CheckExpectedDim(typmod, dim);
|
|
|
|
for (int i = 0; i < dim; i++)
|
|
{
|
|
if (!HalfIsZero(vec->x[i]))
|
|
nnz++;
|
|
}
|
|
|
|
result = InitSparseVector(dim, nnz);
|
|
values = SPARSEVEC_VALUES(result);
|
|
for (int i = 0; i < dim; i++)
|
|
{
|
|
if (!HalfIsZero(vec->x[i]))
|
|
{
|
|
/* Safety check */
|
|
if (j >= result->nnz)
|
|
elog(ERROR, "safety check failed");
|
|
|
|
result->indices[j] = i;
|
|
values[j] = HalfToFloat4(vec->x[i]);
|
|
j++;
|
|
}
|
|
}
|
|
|
|
PG_RETURN_POINTER(result);
|
|
}
|
|
|
|
/*
|
|
* Convert array to sparse vector
|
|
*/
|
|
FUNCTION_PREFIX PG_FUNCTION_INFO_V1(array_to_sparsevec);
|
|
Datum
|
|
array_to_sparsevec(PG_FUNCTION_ARGS)
|
|
{
|
|
ArrayType *array = PG_GETARG_ARRAYTYPE_P(0);
|
|
int32 typmod = PG_GETARG_INT32(1);
|
|
SparseVector *result;
|
|
int16 typlen;
|
|
bool typbyval;
|
|
char typalign;
|
|
Datum *elemsp;
|
|
int nelemsp;
|
|
int nnz = 0;
|
|
float *values;
|
|
int j = 0;
|
|
|
|
if (ARR_NDIM(array) > 1)
|
|
ereport(ERROR,
|
|
(errcode(ERRCODE_DATA_EXCEPTION),
|
|
errmsg("array must be 1-D")));
|
|
|
|
if (ARR_HASNULL(array) && array_contains_nulls(array))
|
|
ereport(ERROR,
|
|
(errcode(ERRCODE_NULL_VALUE_NOT_ALLOWED),
|
|
errmsg("array must not contain nulls")));
|
|
|
|
get_typlenbyvalalign(ARR_ELEMTYPE(array), &typlen, &typbyval, &typalign);
|
|
deconstruct_array(array, ARR_ELEMTYPE(array), typlen, typbyval, typalign, &elemsp, NULL, &nelemsp);
|
|
|
|
CheckDim(nelemsp);
|
|
CheckExpectedDim(typmod, nelemsp);
|
|
|
|
#ifdef _MSC_VER
|
|
/* /fp:fast may not propagate +/-Infinity or NaN */
|
|
#define IS_NOT_ZERO(v) (isnan((float) (v)) || isinf((float) (v)) || ((float) (v)) != 0)
|
|
#else
|
|
#define IS_NOT_ZERO(v) (((float) (v)) != 0)
|
|
#endif
|
|
|
|
if (ARR_ELEMTYPE(array) == INT4OID)
|
|
{
|
|
for (int i = 0; i < nelemsp; i++)
|
|
nnz += IS_NOT_ZERO(DatumGetInt32(elemsp[i]));
|
|
}
|
|
else if (ARR_ELEMTYPE(array) == FLOAT8OID)
|
|
{
|
|
for (int i = 0; i < nelemsp; i++)
|
|
nnz += IS_NOT_ZERO(DatumGetFloat8(elemsp[i]));
|
|
}
|
|
else if (ARR_ELEMTYPE(array) == FLOAT4OID)
|
|
{
|
|
for (int i = 0; i < nelemsp; i++)
|
|
nnz += IS_NOT_ZERO(DatumGetFloat4(elemsp[i]));
|
|
}
|
|
else if (ARR_ELEMTYPE(array) == NUMERICOID)
|
|
{
|
|
for (int i = 0; i < nelemsp; i++)
|
|
nnz += IS_NOT_ZERO(DirectFunctionCall1(numeric_float4, elemsp[i]));
|
|
}
|
|
else
|
|
{
|
|
ereport(ERROR,
|
|
(errcode(ERRCODE_DATA_EXCEPTION),
|
|
errmsg("unsupported array type")));
|
|
}
|
|
|
|
result = InitSparseVector(nelemsp, nnz);
|
|
values = SPARSEVEC_VALUES(result);
|
|
|
|
#define PROCESS_ARRAY_ELEM(elem) \
|
|
do { \
|
|
float v = (float) (elem); \
|
|
if (IS_NOT_ZERO(v)) { \
|
|
/* Safety check */ \
|
|
if (j >= result->nnz) \
|
|
elog(ERROR, "safety check failed"); \
|
|
result->indices[j] = i; \
|
|
values[j] = v; \
|
|
j++; \
|
|
} \
|
|
} while (0)
|
|
|
|
if (ARR_ELEMTYPE(array) == INT4OID)
|
|
{
|
|
for (int i = 0; i < nelemsp; i++)
|
|
PROCESS_ARRAY_ELEM(DatumGetInt32(elemsp[i]));
|
|
}
|
|
else if (ARR_ELEMTYPE(array) == FLOAT8OID)
|
|
{
|
|
for (int i = 0; i < nelemsp; i++)
|
|
PROCESS_ARRAY_ELEM(DatumGetFloat8(elemsp[i]));
|
|
}
|
|
else if (ARR_ELEMTYPE(array) == FLOAT4OID)
|
|
{
|
|
for (int i = 0; i < nelemsp; i++)
|
|
PROCESS_ARRAY_ELEM(DatumGetFloat4(elemsp[i]));
|
|
}
|
|
else if (ARR_ELEMTYPE(array) == NUMERICOID)
|
|
{
|
|
for (int i = 0; i < nelemsp; i++)
|
|
PROCESS_ARRAY_ELEM(DatumGetFloat4(DirectFunctionCall1(numeric_float4, elemsp[i])));
|
|
}
|
|
else
|
|
{
|
|
ereport(ERROR,
|
|
(errcode(ERRCODE_DATA_EXCEPTION),
|
|
errmsg("unsupported array type")));
|
|
}
|
|
|
|
#undef PROCESS_ARRAY_ELEM
|
|
#undef IS_NOT_ZERO
|
|
|
|
/*
|
|
* Free allocation from deconstruct_array. Do not free individual elements
|
|
* when pass-by-reference since they point to original array.
|
|
*/
|
|
pfree(elemsp);
|
|
|
|
if (j != result->nnz)
|
|
elog(ERROR, "correctness check failed");
|
|
|
|
/* Check elements */
|
|
for (int i = 0; i < result->nnz; i++)
|
|
CheckElement(values[i]);
|
|
|
|
PG_RETURN_POINTER(result);
|
|
}
|
|
|
|
/*
|
|
* Get the L2 squared distance between sparse vectors
|
|
*/
|
|
static float
|
|
SparsevecL2SquaredDistance(SparseVector * a, SparseVector * b)
|
|
{
|
|
float *ax = SPARSEVEC_VALUES(a);
|
|
float *bx = SPARSEVEC_VALUES(b);
|
|
float distance = 0.0;
|
|
int bpos = 0;
|
|
|
|
for (int i = 0; i < a->nnz; i++)
|
|
{
|
|
int ai = a->indices[i];
|
|
int bi = -1;
|
|
|
|
for (int j = bpos; j < b->nnz; j++)
|
|
{
|
|
bi = b->indices[j];
|
|
|
|
if (ai == bi)
|
|
{
|
|
float diff = ax[i] - bx[j];
|
|
|
|
distance += diff * diff;
|
|
}
|
|
else if (ai > bi)
|
|
distance += bx[j] * bx[j];
|
|
|
|
/* Update start for next iteration */
|
|
if (ai >= bi)
|
|
bpos = j + 1;
|
|
|
|
/* Found or passed it */
|
|
if (bi >= ai)
|
|
break;
|
|
}
|
|
|
|
if (ai != bi)
|
|
distance += ax[i] * ax[i];
|
|
}
|
|
|
|
for (int j = bpos; j < b->nnz; j++)
|
|
distance += bx[j] * bx[j];
|
|
|
|
return distance;
|
|
}
|
|
|
|
/*
|
|
* Get the L2 distance between sparse vectors
|
|
*/
|
|
FUNCTION_PREFIX PG_FUNCTION_INFO_V1(sparsevec_l2_distance);
|
|
Datum
|
|
sparsevec_l2_distance(PG_FUNCTION_ARGS)
|
|
{
|
|
SparseVector *a = PG_GETARG_SPARSEVEC_P(0);
|
|
SparseVector *b = PG_GETARG_SPARSEVEC_P(1);
|
|
|
|
CheckDims(a, b);
|
|
|
|
PG_RETURN_FLOAT8(sqrt((double) SparsevecL2SquaredDistance(a, b)));
|
|
}
|
|
|
|
/*
|
|
* Get the L2 squared distance between sparse vectors
|
|
* This saves a sqrt calculation
|
|
*/
|
|
FUNCTION_PREFIX PG_FUNCTION_INFO_V1(sparsevec_l2_squared_distance);
|
|
Datum
|
|
sparsevec_l2_squared_distance(PG_FUNCTION_ARGS)
|
|
{
|
|
SparseVector *a = PG_GETARG_SPARSEVEC_P(0);
|
|
SparseVector *b = PG_GETARG_SPARSEVEC_P(1);
|
|
|
|
CheckDims(a, b);
|
|
|
|
PG_RETURN_FLOAT8((double) SparsevecL2SquaredDistance(a, b));
|
|
}
|
|
|
|
/*
|
|
* Get the inner product of two sparse vectors
|
|
*/
|
|
static float
|
|
SparsevecInnerProduct(SparseVector * a, SparseVector * b)
|
|
{
|
|
float *ax = SPARSEVEC_VALUES(a);
|
|
float *bx = SPARSEVEC_VALUES(b);
|
|
float distance = 0.0;
|
|
int bpos = 0;
|
|
|
|
for (int i = 0; i < a->nnz; i++)
|
|
{
|
|
int ai = a->indices[i];
|
|
|
|
for (int j = bpos; j < b->nnz; j++)
|
|
{
|
|
int bi = b->indices[j];
|
|
|
|
/* Only update when the same index */
|
|
if (ai == bi)
|
|
distance += ax[i] * bx[j];
|
|
|
|
/* Update start for next iteration */
|
|
if (ai >= bi)
|
|
bpos = j + 1;
|
|
|
|
/* Found or passed it */
|
|
if (bi >= ai)
|
|
break;
|
|
}
|
|
}
|
|
|
|
return distance;
|
|
}
|
|
|
|
/*
|
|
* Get the inner product of two sparse vectors
|
|
*/
|
|
FUNCTION_PREFIX PG_FUNCTION_INFO_V1(sparsevec_inner_product);
|
|
Datum
|
|
sparsevec_inner_product(PG_FUNCTION_ARGS)
|
|
{
|
|
SparseVector *a = PG_GETARG_SPARSEVEC_P(0);
|
|
SparseVector *b = PG_GETARG_SPARSEVEC_P(1);
|
|
|
|
CheckDims(a, b);
|
|
|
|
PG_RETURN_FLOAT8((double) SparsevecInnerProduct(a, b));
|
|
}
|
|
|
|
/*
|
|
* Get the negative inner product of two sparse vectors
|
|
*/
|
|
FUNCTION_PREFIX PG_FUNCTION_INFO_V1(sparsevec_negative_inner_product);
|
|
Datum
|
|
sparsevec_negative_inner_product(PG_FUNCTION_ARGS)
|
|
{
|
|
SparseVector *a = PG_GETARG_SPARSEVEC_P(0);
|
|
SparseVector *b = PG_GETARG_SPARSEVEC_P(1);
|
|
|
|
CheckDims(a, b);
|
|
|
|
PG_RETURN_FLOAT8((double) -SparsevecInnerProduct(a, b));
|
|
}
|
|
|
|
/*
|
|
* Get the cosine distance between two sparse vectors
|
|
*/
|
|
FUNCTION_PREFIX PG_FUNCTION_INFO_V1(sparsevec_cosine_distance);
|
|
Datum
|
|
sparsevec_cosine_distance(PG_FUNCTION_ARGS)
|
|
{
|
|
SparseVector *a = PG_GETARG_SPARSEVEC_P(0);
|
|
SparseVector *b = PG_GETARG_SPARSEVEC_P(1);
|
|
float *ax = SPARSEVEC_VALUES(a);
|
|
float *bx = SPARSEVEC_VALUES(b);
|
|
float norma = 0.0;
|
|
float normb = 0.0;
|
|
double similarity;
|
|
|
|
CheckDims(a, b);
|
|
|
|
similarity = SparsevecInnerProduct(a, b);
|
|
|
|
/* Auto-vectorized */
|
|
for (int i = 0; i < a->nnz; i++)
|
|
norma += ax[i] * ax[i];
|
|
|
|
/* Auto-vectorized */
|
|
for (int i = 0; i < b->nnz; i++)
|
|
normb += bx[i] * bx[i];
|
|
|
|
/* Use sqrt(a * b) over sqrt(a) * sqrt(b) */
|
|
similarity /= sqrt((double) norma * (double) normb);
|
|
|
|
#ifdef _MSC_VER
|
|
/* /fp:fast may not propagate NaN */
|
|
if (isnan(similarity))
|
|
PG_RETURN_FLOAT8(NAN);
|
|
#endif
|
|
|
|
/* Keep in range */
|
|
if (similarity > 1)
|
|
similarity = 1.0;
|
|
else if (similarity < -1)
|
|
similarity = -1.0;
|
|
|
|
PG_RETURN_FLOAT8(1.0 - similarity);
|
|
}
|
|
|
|
/*
|
|
* Get the L1 distance between two sparse vectors
|
|
*/
|
|
FUNCTION_PREFIX PG_FUNCTION_INFO_V1(sparsevec_l1_distance);
|
|
Datum
|
|
sparsevec_l1_distance(PG_FUNCTION_ARGS)
|
|
{
|
|
SparseVector *a = PG_GETARG_SPARSEVEC_P(0);
|
|
SparseVector *b = PG_GETARG_SPARSEVEC_P(1);
|
|
float *ax = SPARSEVEC_VALUES(a);
|
|
float *bx = SPARSEVEC_VALUES(b);
|
|
float distance = 0.0;
|
|
int bpos = 0;
|
|
|
|
CheckDims(a, b);
|
|
|
|
for (int i = 0; i < a->nnz; i++)
|
|
{
|
|
int ai = a->indices[i];
|
|
int bi = -1;
|
|
|
|
for (int j = bpos; j < b->nnz; j++)
|
|
{
|
|
bi = b->indices[j];
|
|
|
|
if (ai == bi)
|
|
distance += fabsf(ax[i] - bx[j]);
|
|
else if (ai > bi)
|
|
distance += fabsf(bx[j]);
|
|
|
|
/* Update start for next iteration */
|
|
if (ai >= bi)
|
|
bpos = j + 1;
|
|
|
|
/* Found or passed it */
|
|
if (bi >= ai)
|
|
break;
|
|
}
|
|
|
|
if (ai != bi)
|
|
distance += fabsf(ax[i]);
|
|
}
|
|
|
|
for (int j = bpos; j < b->nnz; j++)
|
|
distance += fabsf(bx[j]);
|
|
|
|
PG_RETURN_FLOAT8((double) distance);
|
|
}
|
|
|
|
/*
|
|
* Get the L2 norm of a sparse vector
|
|
*/
|
|
FUNCTION_PREFIX PG_FUNCTION_INFO_V1(sparsevec_l2_norm);
|
|
Datum
|
|
sparsevec_l2_norm(PG_FUNCTION_ARGS)
|
|
{
|
|
SparseVector *a = PG_GETARG_SPARSEVEC_P(0);
|
|
float *ax = SPARSEVEC_VALUES(a);
|
|
double norm = 0.0;
|
|
|
|
/* Auto-vectorized */
|
|
for (int i = 0; i < a->nnz; i++)
|
|
norm += (double) ax[i] * (double) ax[i];
|
|
|
|
PG_RETURN_FLOAT8(sqrt(norm));
|
|
}
|
|
|
|
/*
|
|
* Normalize a sparse vector with the L2 norm
|
|
*/
|
|
FUNCTION_PREFIX PG_FUNCTION_INFO_V1(sparsevec_l2_normalize);
|
|
Datum
|
|
sparsevec_l2_normalize(PG_FUNCTION_ARGS)
|
|
{
|
|
SparseVector *a = PG_GETARG_SPARSEVEC_P(0);
|
|
float *ax = SPARSEVEC_VALUES(a);
|
|
double norm = 0;
|
|
SparseVector *result;
|
|
float *rx;
|
|
|
|
result = InitSparseVector(a->dim, a->nnz);
|
|
rx = SPARSEVEC_VALUES(result);
|
|
|
|
/* Auto-vectorized */
|
|
for (int i = 0; i < a->nnz; i++)
|
|
norm += (double) ax[i] * (double) ax[i];
|
|
|
|
norm = sqrt(norm);
|
|
|
|
/* Return zero vector for zero norm */
|
|
if (norm > 0)
|
|
{
|
|
int zeros = 0;
|
|
|
|
for (int i = 0; i < a->nnz; i++)
|
|
{
|
|
result->indices[i] = a->indices[i];
|
|
rx[i] = ax[i] / norm;
|
|
|
|
if (isinf(rx[i]))
|
|
float_overflow_error();
|
|
|
|
if (rx[i] == 0)
|
|
zeros++;
|
|
}
|
|
|
|
/* Allocate a new vector in the unlikely event there are zeros */
|
|
if (zeros > 0)
|
|
{
|
|
SparseVector *newResult = InitSparseVector(result->dim, result->nnz - zeros);
|
|
float *nx = SPARSEVEC_VALUES(newResult);
|
|
int j = 0;
|
|
|
|
for (int i = 0; i < result->nnz; i++)
|
|
{
|
|
if (rx[i] == 0)
|
|
continue;
|
|
|
|
/* Safety check */
|
|
if (j >= newResult->nnz)
|
|
elog(ERROR, "safety check failed");
|
|
|
|
newResult->indices[j] = result->indices[i];
|
|
nx[j] = rx[i];
|
|
j++;
|
|
}
|
|
|
|
pfree(result);
|
|
|
|
PG_RETURN_POINTER(newResult);
|
|
}
|
|
}
|
|
|
|
PG_RETURN_POINTER(result);
|
|
}
|
|
|
|
/*
|
|
* Internal helper to compare sparse vectors
|
|
*/
|
|
static int
|
|
sparsevec_cmp_internal(SparseVector * a, SparseVector * b)
|
|
{
|
|
float *ax = SPARSEVEC_VALUES(a);
|
|
float *bx = SPARSEVEC_VALUES(b);
|
|
int nnz = Min(a->nnz, b->nnz);
|
|
|
|
/* Check values before dimensions to be consistent with Postgres arrays */
|
|
for (int i = 0; i < nnz; i++)
|
|
{
|
|
if (a->indices[i] < b->indices[i])
|
|
return ax[i] < 0 ? -1 : 1;
|
|
|
|
if (a->indices[i] > b->indices[i])
|
|
return bx[i] < 0 ? 1 : -1;
|
|
|
|
if (ax[i] < bx[i])
|
|
return -1;
|
|
|
|
if (ax[i] > bx[i])
|
|
return 1;
|
|
}
|
|
|
|
if (a->nnz < b->nnz && b->indices[nnz] < a->dim)
|
|
return bx[nnz] < 0 ? 1 : -1;
|
|
|
|
if (a->nnz > b->nnz && a->indices[nnz] < b->dim)
|
|
return ax[nnz] < 0 ? -1 : 1;
|
|
|
|
if (a->dim < b->dim)
|
|
return -1;
|
|
|
|
if (a->dim > b->dim)
|
|
return 1;
|
|
|
|
return 0;
|
|
}
|
|
|
|
/*
|
|
* Less than
|
|
*/
|
|
FUNCTION_PREFIX PG_FUNCTION_INFO_V1(sparsevec_lt);
|
|
Datum
|
|
sparsevec_lt(PG_FUNCTION_ARGS)
|
|
{
|
|
SparseVector *a = PG_GETARG_SPARSEVEC_P(0);
|
|
SparseVector *b = PG_GETARG_SPARSEVEC_P(1);
|
|
|
|
PG_RETURN_BOOL(sparsevec_cmp_internal(a, b) < 0);
|
|
}
|
|
|
|
/*
|
|
* Less than or equal
|
|
*/
|
|
FUNCTION_PREFIX PG_FUNCTION_INFO_V1(sparsevec_le);
|
|
Datum
|
|
sparsevec_le(PG_FUNCTION_ARGS)
|
|
{
|
|
SparseVector *a = PG_GETARG_SPARSEVEC_P(0);
|
|
SparseVector *b = PG_GETARG_SPARSEVEC_P(1);
|
|
|
|
PG_RETURN_BOOL(sparsevec_cmp_internal(a, b) <= 0);
|
|
}
|
|
|
|
/*
|
|
* Equal
|
|
*/
|
|
FUNCTION_PREFIX PG_FUNCTION_INFO_V1(sparsevec_eq);
|
|
Datum
|
|
sparsevec_eq(PG_FUNCTION_ARGS)
|
|
{
|
|
SparseVector *a = PG_GETARG_SPARSEVEC_P(0);
|
|
SparseVector *b = PG_GETARG_SPARSEVEC_P(1);
|
|
|
|
PG_RETURN_BOOL(sparsevec_cmp_internal(a, b) == 0);
|
|
}
|
|
|
|
/*
|
|
* Not equal
|
|
*/
|
|
FUNCTION_PREFIX PG_FUNCTION_INFO_V1(sparsevec_ne);
|
|
Datum
|
|
sparsevec_ne(PG_FUNCTION_ARGS)
|
|
{
|
|
SparseVector *a = PG_GETARG_SPARSEVEC_P(0);
|
|
SparseVector *b = PG_GETARG_SPARSEVEC_P(1);
|
|
|
|
PG_RETURN_BOOL(sparsevec_cmp_internal(a, b) != 0);
|
|
}
|
|
|
|
/*
|
|
* Greater than or equal
|
|
*/
|
|
FUNCTION_PREFIX PG_FUNCTION_INFO_V1(sparsevec_ge);
|
|
Datum
|
|
sparsevec_ge(PG_FUNCTION_ARGS)
|
|
{
|
|
SparseVector *a = PG_GETARG_SPARSEVEC_P(0);
|
|
SparseVector *b = PG_GETARG_SPARSEVEC_P(1);
|
|
|
|
PG_RETURN_BOOL(sparsevec_cmp_internal(a, b) >= 0);
|
|
}
|
|
|
|
/*
|
|
* Greater than
|
|
*/
|
|
FUNCTION_PREFIX PG_FUNCTION_INFO_V1(sparsevec_gt);
|
|
Datum
|
|
sparsevec_gt(PG_FUNCTION_ARGS)
|
|
{
|
|
SparseVector *a = PG_GETARG_SPARSEVEC_P(0);
|
|
SparseVector *b = PG_GETARG_SPARSEVEC_P(1);
|
|
|
|
PG_RETURN_BOOL(sparsevec_cmp_internal(a, b) > 0);
|
|
}
|
|
|
|
/*
|
|
* Compare sparse vectors
|
|
*/
|
|
FUNCTION_PREFIX PG_FUNCTION_INFO_V1(sparsevec_cmp);
|
|
Datum
|
|
sparsevec_cmp(PG_FUNCTION_ARGS)
|
|
{
|
|
SparseVector *a = PG_GETARG_SPARSEVEC_P(0);
|
|
SparseVector *b = PG_GETARG_SPARSEVEC_P(1);
|
|
|
|
PG_RETURN_INT32(sparsevec_cmp_internal(a, b));
|
|
}
|