From 3fb05eb847cd384d37d1fd2fa66a54296ab3f66f Mon Sep 17 00:00:00 2001 From: Andrew Kane Date: Thu, 19 Sep 2024 19:17:05 -0700 Subject: [PATCH] Added casts for arrays to sparsevec - #604 Co-authored-by: Narek Galstyan Co-authored-by: Di Qi --- CHANGELOG.md | 1 + sql/vector--0.7.4--0.8.0.sql | 26 ++++++++ sql/vector.sql | 24 +++++++ src/sparsevec.c | 122 +++++++++++++++++++++++++++++++++++ test/expected/cast.out | 56 ++++++++++++++++ test/sql/cast.sql | 16 +++++ 6 files changed, 245 insertions(+) create mode 100644 sql/vector--0.7.4--0.8.0.sql diff --git a/CHANGELOG.md b/CHANGELOG.md index 55d4ad8..db6798c 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -1,5 +1,6 @@ ## 0.8.0 (unreleased) +- Added casts for arrays to `sparsevec` - Reduced memory usage for HNSW index scans - Dropped support for Postgres 12 diff --git a/sql/vector--0.7.4--0.8.0.sql b/sql/vector--0.7.4--0.8.0.sql new file mode 100644 index 0000000..e00348d --- /dev/null +++ b/sql/vector--0.7.4--0.8.0.sql @@ -0,0 +1,26 @@ +-- complain if script is sourced in psql, rather than via CREATE EXTENSION +\echo Use "ALTER EXTENSION vector UPDATE TO '0.8.0'" to load this file. \quit + +CREATE FUNCTION array_to_sparsevec(integer[], integer, boolean) RETURNS sparsevec + AS 'MODULE_PATHNAME' LANGUAGE C IMMUTABLE STRICT PARALLEL SAFE; + +CREATE FUNCTION array_to_sparsevec(real[], integer, boolean) RETURNS sparsevec + AS 'MODULE_PATHNAME' LANGUAGE C IMMUTABLE STRICT PARALLEL SAFE; + +CREATE FUNCTION array_to_sparsevec(double precision[], integer, boolean) RETURNS sparsevec + AS 'MODULE_PATHNAME' LANGUAGE C IMMUTABLE STRICT PARALLEL SAFE; + +CREATE FUNCTION array_to_sparsevec(numeric[], integer, boolean) RETURNS sparsevec + AS 'MODULE_PATHNAME' LANGUAGE C IMMUTABLE STRICT PARALLEL SAFE; + +CREATE CAST (integer[] AS sparsevec) + WITH FUNCTION array_to_sparsevec(integer[], integer, boolean) AS ASSIGNMENT; + +CREATE CAST (real[] AS sparsevec) + WITH FUNCTION array_to_sparsevec(real[], integer, boolean) AS ASSIGNMENT; + +CREATE CAST (double precision[] AS sparsevec) + WITH FUNCTION array_to_sparsevec(double precision[], integer, boolean) AS ASSIGNMENT; + +CREATE CAST (numeric[] AS sparsevec) + WITH FUNCTION array_to_sparsevec(numeric[], integer, boolean) AS ASSIGNMENT; diff --git a/sql/vector.sql b/sql/vector.sql index 32eb834..7fc3671 100644 --- a/sql/vector.sql +++ b/sql/vector.sql @@ -782,6 +782,18 @@ CREATE FUNCTION halfvec_to_sparsevec(halfvec, integer, boolean) RETURNS sparseve CREATE FUNCTION sparsevec_to_halfvec(sparsevec, integer, boolean) RETURNS halfvec AS 'MODULE_PATHNAME' LANGUAGE C IMMUTABLE STRICT PARALLEL SAFE; +CREATE FUNCTION array_to_sparsevec(integer[], integer, boolean) RETURNS sparsevec + AS 'MODULE_PATHNAME' LANGUAGE C IMMUTABLE STRICT PARALLEL SAFE; + +CREATE FUNCTION array_to_sparsevec(real[], integer, boolean) RETURNS sparsevec + AS 'MODULE_PATHNAME' LANGUAGE C IMMUTABLE STRICT PARALLEL SAFE; + +CREATE FUNCTION array_to_sparsevec(double precision[], integer, boolean) RETURNS sparsevec + AS 'MODULE_PATHNAME' LANGUAGE C IMMUTABLE STRICT PARALLEL SAFE; + +CREATE FUNCTION array_to_sparsevec(numeric[], integer, boolean) RETURNS sparsevec + AS 'MODULE_PATHNAME' LANGUAGE C IMMUTABLE STRICT PARALLEL SAFE; + -- sparsevec casts CREATE CAST (sparsevec AS sparsevec) @@ -799,6 +811,18 @@ CREATE CAST (sparsevec AS halfvec) CREATE CAST (halfvec AS sparsevec) WITH FUNCTION halfvec_to_sparsevec(halfvec, integer, boolean) AS IMPLICIT; +CREATE CAST (integer[] AS sparsevec) + WITH FUNCTION array_to_sparsevec(integer[], integer, boolean) AS ASSIGNMENT; + +CREATE CAST (real[] AS sparsevec) + WITH FUNCTION array_to_sparsevec(real[], integer, boolean) AS ASSIGNMENT; + +CREATE CAST (double precision[] AS sparsevec) + WITH FUNCTION array_to_sparsevec(double precision[], integer, boolean) AS ASSIGNMENT; + +CREATE CAST (numeric[] AS sparsevec) + WITH FUNCTION array_to_sparsevec(numeric[], integer, boolean) AS ASSIGNMENT; + -- sparsevec operators CREATE OPERATOR <-> ( diff --git a/src/sparsevec.c b/src/sparsevec.c index 4211fd8..55c14c4 100644 --- a/src/sparsevec.c +++ b/src/sparsevec.c @@ -3,6 +3,7 @@ #include #include +#include "catalog/pg_type.h" #include "common/string.h" #include "fmgr.h" #include "halfutils.h" @@ -11,6 +12,7 @@ #include "sparsevec.h" #include "utils/array.h" #include "utils/builtins.h" +#include "utils/lsyscache.h" #include "vector.h" #if PG_VERSION_NUM >= 120000 @@ -670,6 +672,126 @@ halfvec_to_sparsevec(PG_FUNCTION_ARGS) PG_RETURN_POINTER(result); } +/* + * Convert array to sparse vector + */ +FUNCTION_PREFIX PG_FUNCTION_INFO_V1(array_to_sparsevec); +Datum +array_to_sparsevec(PG_FUNCTION_ARGS) +{ + ArrayType *array = PG_GETARG_ARRAYTYPE_P(0); + int32 typmod = PG_GETARG_INT32(1); + SparseVector *result; + int16 typlen; + bool typbyval; + char typalign; + Datum *elemsp; + int nelemsp; + int nnz = 0; + float *values; + int j = 0; + + if (ARR_NDIM(array) > 1) + ereport(ERROR, + (errcode(ERRCODE_DATA_EXCEPTION), + errmsg("array must be 1-D"))); + + if (ARR_HASNULL(array) && array_contains_nulls(array)) + ereport(ERROR, + (errcode(ERRCODE_NULL_VALUE_NOT_ALLOWED), + errmsg("array must not contain nulls"))); + + get_typlenbyvalalign(ARR_ELEMTYPE(array), &typlen, &typbyval, &typalign); + deconstruct_array(array, ARR_ELEMTYPE(array), typlen, typbyval, typalign, &elemsp, NULL, &nelemsp); + + CheckDim(nelemsp); + CheckExpectedDim(typmod, nelemsp); + + if (ARR_ELEMTYPE(array) == INT4OID) + { + for (int i = 0; i < nelemsp; i++) + nnz += ((float) DatumGetInt32(elemsp[i])) != 0; + } + else if (ARR_ELEMTYPE(array) == FLOAT8OID) + { + for (int i = 0; i < nelemsp; i++) + nnz += ((float) DatumGetFloat8(elemsp[i])) != 0; + } + else if (ARR_ELEMTYPE(array) == FLOAT4OID) + { + for (int i = 0; i < nelemsp; i++) + nnz += (DatumGetFloat4(elemsp[i]) != 0); + } + else if (ARR_ELEMTYPE(array) == NUMERICOID) + { + for (int i = 0; i < nelemsp; i++) + nnz += (DatumGetFloat4(DirectFunctionCall1(numeric_float4, elemsp[i])) != 0); + } + else + { + ereport(ERROR, + (errcode(ERRCODE_DATA_EXCEPTION), + errmsg("unsupported array type"))); + } + + result = InitSparseVector(nelemsp, nnz); + values = SPARSEVEC_VALUES(result); + +#define PROCESS_ARRAY_ELEM(elem) \ + do { \ + float v = (float) (elem); \ + if (v != 0) { \ + /* Safety check */ \ + if (j >= result->nnz) \ + elog(ERROR, "safety check failed"); \ + result->indices[j] = i; \ + values[j] = v; \ + j++; \ + } \ + } while (0) + + if (ARR_ELEMTYPE(array) == INT4OID) + { + for (int i = 0; i < nelemsp; i++) + PROCESS_ARRAY_ELEM(DatumGetInt32(elemsp[i])); + } + else if (ARR_ELEMTYPE(array) == FLOAT8OID) + { + for (int i = 0; i < nelemsp; i++) + PROCESS_ARRAY_ELEM(DatumGetFloat8(elemsp[i])); + } + else if (ARR_ELEMTYPE(array) == FLOAT4OID) + { + for (int i = 0; i < nelemsp; i++) + PROCESS_ARRAY_ELEM(DatumGetFloat4(elemsp[i])); + } + else if (ARR_ELEMTYPE(array) == NUMERICOID) + { + for (int i = 0; i < nelemsp; i++) + PROCESS_ARRAY_ELEM(DatumGetFloat4(DirectFunctionCall1(numeric_float4, elemsp[i]))); + } + else + { + ereport(ERROR, + (errcode(ERRCODE_DATA_EXCEPTION), + errmsg("unsupported array type"))); + } + +#undef PROCESS_ARRAY_ELEM + + /* + * Free allocation from deconstruct_array. Do not free individual elements + * when pass-by-reference since they point to original array. + */ + pfree(elemsp); + + /* Check elements */ + for (int i = 0; i < result->nnz; i++) + CheckElement(values[i]); + + PG_RETURN_POINTER(result); +} + /* * Get the L2 squared distance between sparse vectors */ diff --git a/test/expected/cast.out b/test/expected/cast.out index 1aba43c..c180fe6 100644 --- a/test/expected/cast.out +++ b/test/expected/cast.out @@ -208,6 +208,62 @@ SELECT '{1:1e-8}/1'::sparsevec::halfvec; [0] (1 row) +SELECT ARRAY[1,0,2,0,3,0]::sparsevec; + array +----------------- + {1:1,3:2,5:3}/6 +(1 row) + +SELECT ARRAY[1.0,0.0,2.0,0.0,3.0,0.0]::sparsevec; + array +----------------- + {1:1,3:2,5:3}/6 +(1 row) + +SELECT ARRAY[1,0,2,0,3,0]::float4[]::sparsevec; + array +----------------- + {1:1,3:2,5:3}/6 +(1 row) + +SELECT ARRAY[1,0,2,0,3,0]::float8[]::sparsevec; + array +----------------- + {1:1,3:2,5:3}/6 +(1 row) + +SELECT ARRAY[1,0,2,0,3,0]::numeric[]::sparsevec; + array +----------------- + {1:1,3:2,5:3}/6 +(1 row) + +SELECT '{1,0,2,0,3,0}'::real[]::sparsevec; + sparsevec +----------------- + {1:1,3:2,5:3}/6 +(1 row) + +SELECT '{1,0,2,0,3,0}'::real[]::sparsevec(6); + sparsevec +----------------- + {1:1,3:2,5:3}/6 +(1 row) + +SELECT '{1,0,2,0,3,0}'::real[]::sparsevec(5); +ERROR: expected 5 dimensions, not 6 +SELECT '{NULL}'::real[]::sparsevec; +ERROR: array must not contain nulls +SELECT '{NaN}'::real[]::sparsevec; +ERROR: NaN not allowed in sparsevec +SELECT '{Infinity}'::real[]::sparsevec; +ERROR: infinite value not allowed in sparsevec +SELECT '{-Infinity}'::real[]::sparsevec; +ERROR: infinite value not allowed in sparsevec +SELECT '{}'::real[]::sparsevec; +ERROR: sparsevec must have at least 1 dimension +SELECT '{{1}}'::real[]::sparsevec; +ERROR: array must be 1-D SELECT array_agg(n)::vector FROM generate_series(1, 16001) n; ERROR: vector cannot have more than 16000 dimensions SELECT array_to_vector(array_agg(n), 16001, false) FROM generate_series(1, 16001) n; diff --git a/test/sql/cast.sql b/test/sql/cast.sql index cd2eb3c..fe83931 100644 --- a/test/sql/cast.sql +++ b/test/sql/cast.sql @@ -58,6 +58,22 @@ SELECT '{}/16001'::sparsevec::halfvec; SELECT '{1:65520}/1'::sparsevec::halfvec; SELECT '{1:1e-8}/1'::sparsevec::halfvec; +SELECT ARRAY[1,0,2,0,3,0]::sparsevec; +SELECT ARRAY[1.0,0.0,2.0,0.0,3.0,0.0]::sparsevec; +SELECT ARRAY[1,0,2,0,3,0]::float4[]::sparsevec; +SELECT ARRAY[1,0,2,0,3,0]::float8[]::sparsevec; +SELECT ARRAY[1,0,2,0,3,0]::numeric[]::sparsevec; + +SELECT '{1,0,2,0,3,0}'::real[]::sparsevec; +SELECT '{1,0,2,0,3,0}'::real[]::sparsevec(6); +SELECT '{1,0,2,0,3,0}'::real[]::sparsevec(5); +SELECT '{NULL}'::real[]::sparsevec; +SELECT '{NaN}'::real[]::sparsevec; +SELECT '{Infinity}'::real[]::sparsevec; +SELECT '{-Infinity}'::real[]::sparsevec; +SELECT '{}'::real[]::sparsevec; +SELECT '{{1}}'::real[]::sparsevec; + SELECT array_agg(n)::vector FROM generate_series(1, 16001) n; SELECT array_to_vector(array_agg(n), 16001, false) FROM generate_series(1, 16001) n;