mirror of
https://github.com/pgvector/pgvector.git
synced 2026-06-30 01:31:15 +08:00
Added casts for arrays to sparsevec - #604
Co-authored-by: Narek Galstyan <narekg@berkeley.edu> Co-authored-by: Di Qi <di@lantern.dev>
This commit is contained in:
122
src/sparsevec.c
122
src/sparsevec.c
@@ -3,6 +3,7 @@
|
||||
#include <limits.h>
|
||||
#include <math.h>
|
||||
|
||||
#include "catalog/pg_type.h"
|
||||
#include "common/string.h"
|
||||
#include "fmgr.h"
|
||||
#include "halfutils.h"
|
||||
@@ -11,6 +12,7 @@
|
||||
#include "sparsevec.h"
|
||||
#include "utils/array.h"
|
||||
#include "utils/builtins.h"
|
||||
#include "utils/lsyscache.h"
|
||||
#include "vector.h"
|
||||
|
||||
#if PG_VERSION_NUM >= 120000
|
||||
@@ -670,6 +672,126 @@ halfvec_to_sparsevec(PG_FUNCTION_ARGS)
|
||||
PG_RETURN_POINTER(result);
|
||||
}
|
||||
|
||||
/*
|
||||
* Convert array to sparse vector
|
||||
*/
|
||||
FUNCTION_PREFIX PG_FUNCTION_INFO_V1(array_to_sparsevec);
|
||||
Datum
|
||||
array_to_sparsevec(PG_FUNCTION_ARGS)
|
||||
{
|
||||
ArrayType *array = PG_GETARG_ARRAYTYPE_P(0);
|
||||
int32 typmod = PG_GETARG_INT32(1);
|
||||
SparseVector *result;
|
||||
int16 typlen;
|
||||
bool typbyval;
|
||||
char typalign;
|
||||
Datum *elemsp;
|
||||
int nelemsp;
|
||||
int nnz = 0;
|
||||
float *values;
|
||||
int j = 0;
|
||||
|
||||
if (ARR_NDIM(array) > 1)
|
||||
ereport(ERROR,
|
||||
(errcode(ERRCODE_DATA_EXCEPTION),
|
||||
errmsg("array must be 1-D")));
|
||||
|
||||
if (ARR_HASNULL(array) && array_contains_nulls(array))
|
||||
ereport(ERROR,
|
||||
(errcode(ERRCODE_NULL_VALUE_NOT_ALLOWED),
|
||||
errmsg("array must not contain nulls")));
|
||||
|
||||
get_typlenbyvalalign(ARR_ELEMTYPE(array), &typlen, &typbyval, &typalign);
|
||||
deconstruct_array(array, ARR_ELEMTYPE(array), typlen, typbyval, typalign, &elemsp, NULL, &nelemsp);
|
||||
|
||||
CheckDim(nelemsp);
|
||||
CheckExpectedDim(typmod, nelemsp);
|
||||
|
||||
if (ARR_ELEMTYPE(array) == INT4OID)
|
||||
{
|
||||
for (int i = 0; i < nelemsp; i++)
|
||||
nnz += ((float) DatumGetInt32(elemsp[i])) != 0;
|
||||
}
|
||||
else if (ARR_ELEMTYPE(array) == FLOAT8OID)
|
||||
{
|
||||
for (int i = 0; i < nelemsp; i++)
|
||||
nnz += ((float) DatumGetFloat8(elemsp[i])) != 0;
|
||||
}
|
||||
else if (ARR_ELEMTYPE(array) == FLOAT4OID)
|
||||
{
|
||||
for (int i = 0; i < nelemsp; i++)
|
||||
nnz += (DatumGetFloat4(elemsp[i]) != 0);
|
||||
}
|
||||
else if (ARR_ELEMTYPE(array) == NUMERICOID)
|
||||
{
|
||||
for (int i = 0; i < nelemsp; i++)
|
||||
nnz += (DatumGetFloat4(DirectFunctionCall1(numeric_float4, elemsp[i])) != 0);
|
||||
}
|
||||
else
|
||||
{
|
||||
ereport(ERROR,
|
||||
(errcode(ERRCODE_DATA_EXCEPTION),
|
||||
errmsg("unsupported array type")));
|
||||
}
|
||||
|
||||
result = InitSparseVector(nelemsp, nnz);
|
||||
values = SPARSEVEC_VALUES(result);
|
||||
|
||||
#define PROCESS_ARRAY_ELEM(elem) \
|
||||
do { \
|
||||
float v = (float) (elem); \
|
||||
if (v != 0) { \
|
||||
/* Safety check */ \
|
||||
if (j >= result->nnz) \
|
||||
elog(ERROR, "safety check failed"); \
|
||||
result->indices[j] = i; \
|
||||
values[j] = v; \
|
||||
j++; \
|
||||
} \
|
||||
} while (0)
|
||||
|
||||
if (ARR_ELEMTYPE(array) == INT4OID)
|
||||
{
|
||||
for (int i = 0; i < nelemsp; i++)
|
||||
PROCESS_ARRAY_ELEM(DatumGetInt32(elemsp[i]));
|
||||
}
|
||||
else if (ARR_ELEMTYPE(array) == FLOAT8OID)
|
||||
{
|
||||
for (int i = 0; i < nelemsp; i++)
|
||||
PROCESS_ARRAY_ELEM(DatumGetFloat8(elemsp[i]));
|
||||
}
|
||||
else if (ARR_ELEMTYPE(array) == FLOAT4OID)
|
||||
{
|
||||
for (int i = 0; i < nelemsp; i++)
|
||||
PROCESS_ARRAY_ELEM(DatumGetFloat4(elemsp[i]));
|
||||
}
|
||||
else if (ARR_ELEMTYPE(array) == NUMERICOID)
|
||||
{
|
||||
for (int i = 0; i < nelemsp; i++)
|
||||
PROCESS_ARRAY_ELEM(DatumGetFloat4(DirectFunctionCall1(numeric_float4, elemsp[i])));
|
||||
}
|
||||
else
|
||||
{
|
||||
ereport(ERROR,
|
||||
(errcode(ERRCODE_DATA_EXCEPTION),
|
||||
errmsg("unsupported array type")));
|
||||
}
|
||||
|
||||
#undef PROCESS_ARRAY_ELEM
|
||||
|
||||
/*
|
||||
* Free allocation from deconstruct_array. Do not free individual elements
|
||||
* when pass-by-reference since they point to original array.
|
||||
*/
|
||||
pfree(elemsp);
|
||||
|
||||
/* Check elements */
|
||||
for (int i = 0; i < result->nnz; i++)
|
||||
CheckElement(values[i]);
|
||||
|
||||
PG_RETURN_POINTER(result);
|
||||
}
|
||||
|
||||
/*
|
||||
* Get the L2 squared distance between sparse vectors
|
||||
*/
|
||||
|
||||
Reference in New Issue
Block a user