From 71ee682ed481fb0254bd8adf26aa913c2e5c2c51 Mon Sep 17 00:00:00 2001 From: Andrew Kane Date: Sun, 17 Mar 2024 19:25:39 -0700 Subject: [PATCH] Added subscript function for vectors --- CHANGELOG.md | 4 + sql/vector--0.6.2--0.7.0.sql | 7 ++ sql/vector.sql | 4 + src/vector.c | 181 +++++++++++++++++++++++++++++++++++ test/expected/functions.out | 106 ++++++++++++++++++++ test/sql/functions.sql | 20 ++++ 6 files changed, 322 insertions(+) create mode 100644 sql/vector--0.6.2--0.7.0.sql diff --git a/CHANGELOG.md b/CHANGELOG.md index bf81eb0..49bcaab 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -1,3 +1,7 @@ +## 0.7.0 (unreleased) + +- Added subscript function for vectors + ## 0.6.2 (unreleased) - Reduced lock contention with parallel HNSW index builds diff --git a/sql/vector--0.6.2--0.7.0.sql b/sql/vector--0.6.2--0.7.0.sql new file mode 100644 index 0000000..9d1f565 --- /dev/null +++ b/sql/vector--0.6.2--0.7.0.sql @@ -0,0 +1,7 @@ +-- complain if script is sourced in psql, rather than via CREATE EXTENSION +\echo Use "ALTER EXTENSION vector UPDATE TO '0.7.0'" to load this file. \quit + +CREATE FUNCTION vector_subscript(internal) RETURNS internal + AS 'MODULE_PATHNAME' LANGUAGE C IMMUTABLE STRICT PARALLEL SAFE; + +ALTER TYPE vector SET (SUBSCRIPT = vector_subscript); diff --git a/sql/vector.sql b/sql/vector.sql index 141e83c..195eac6 100644 --- a/sql/vector.sql +++ b/sql/vector.sql @@ -20,12 +20,16 @@ CREATE FUNCTION vector_recv(internal, oid, integer) RETURNS vector CREATE FUNCTION vector_send(vector) RETURNS bytea AS 'MODULE_PATHNAME' LANGUAGE C IMMUTABLE STRICT PARALLEL SAFE; +CREATE FUNCTION vector_subscript(internal) RETURNS internal + AS 'MODULE_PATHNAME' LANGUAGE C IMMUTABLE STRICT PARALLEL SAFE; + CREATE TYPE vector ( INPUT = vector_in, OUTPUT = vector_out, TYPMOD_IN = vector_typmod_in, RECEIVE = vector_recv, SEND = vector_send, + SUBSCRIPT = vector_subscript, STORAGE = external ); diff --git a/src/vector.c b/src/vector.c index 5f3cbbb..49b5692 100644 --- a/src/vector.c +++ b/src/vector.c @@ -4,11 +4,18 @@ #include "catalog/pg_type.h" #include "common/shortest_dec.h" +#include "executor/execExpr.h" #include "fmgr.h" #include "hnsw.h" #include "ivfflat.h" #include "lib/stringinfo.h" #include "libpq/pqformat.h" +#include "nodes/makefuncs.h" +#include "nodes/nodeFuncs.h" +#include "nodes/subscripting.h" +#include "parser/parse_coerce.h" +#include "parser/parse_expr.h" +#include "parser/parse_node.h" #include "port.h" /* for strtof() */ #include "utils/array.h" #include "utils/builtins.h" @@ -418,6 +425,180 @@ vector_send(PG_FUNCTION_ARGS) PG_RETURN_BYTEA_P(pq_endtypsend(&buf)); } +/* + * Transform the subscript expressions + */ +static void +vector_subscript_transform(SubscriptingRef *sbsref, List *indirection, ParseState *pstate, bool isSlice, bool isAssignment) +{ + A_Indices *ai; + Node *subexpr; + + if (list_length(indirection) != 1) + ereport(ERROR, + (errcode(ERRCODE_FEATURE_NOT_SUPPORTED), + errmsg("vector allows only one subscript"), + parser_errposition(pstate, + exprLocation((Node *) indirection)))); + + ai = linitial_node(A_Indices, indirection); + + if (isSlice) + { + if (ai->lidx) + { + subexpr = transformExpr(pstate, ai->lidx, pstate->p_expr_kind); + /* If it's not int4 already, try to coerce */ + subexpr = coerce_to_target_type(pstate, + subexpr, exprType(subexpr), + INT4OID, -1, + COERCION_ASSIGNMENT, + COERCE_IMPLICIT_CAST, + -1); + if (subexpr == NULL) + ereport(ERROR, + (errcode(ERRCODE_DATATYPE_MISMATCH), + errmsg("vector subscript must have type integer"), + parser_errposition(pstate, exprLocation(ai->lidx)))); + } + else if (!ai->is_slice) + { + /* Make a constant 1 */ + subexpr = (Node *) makeConst(INT4OID, + -1, + InvalidOid, + sizeof(int32), + Int32GetDatum(1), + false, + true); /* pass by value */ + } + else + { + /* Slice with omitted lower bound, put NULL into the list */ + subexpr = NULL; + } + sbsref->reflowerindexpr = list_make1(subexpr); + } + else + Assert(ai->lidx == NULL && !ai->is_slice); + + if (ai->uidx) + { + subexpr = transformExpr(pstate, ai->uidx, pstate->p_expr_kind); + /* If it's not int4 already, try to coerce */ + subexpr = coerce_to_target_type(pstate, + subexpr, exprType(subexpr), + INT4OID, -1, + COERCION_ASSIGNMENT, + COERCE_IMPLICIT_CAST, + -1); + if (subexpr == NULL) + ereport(ERROR, + (errcode(ERRCODE_DATATYPE_MISMATCH), + errmsg("array subscript must have type integer"), + parser_errposition(pstate, exprLocation(ai->uidx)))); + } + else + { + /* Slice with omitted upper bound, put NULL into the list */ + Assert(isSlice && ai->is_slice); + subexpr = NULL; + } + sbsref->refupperindexpr = list_make1(subexpr); + + if (isSlice) + sbsref->refrestype = sbsref->refcontainertype; + else + sbsref->refrestype = FLOAT4OID; +} + +/* + * Fetch a vector element + */ +static void +vector_subscript_fetch(ExprState *state, ExprEvalStep *op, ExprContext *econtext) +{ + SubscriptingRefState *sbsrefstate = op->d.sbsref.state; + Vector *vec = DatumGetVector(*op->resvalue); + int index = DatumGetInt32(sbsrefstate->upperindex[0]); + + if (index < 1 || index > vec->dim) + *op->resnull = true; + else + *op->resvalue = Float4GetDatum(vec->x[index - 1]); +} + +/* + * Fetch a vector slice + */ +static void +vector_subscript_fetch_slice(ExprState *state, ExprEvalStep *op, ExprContext *econtext) +{ + SubscriptingRefState *sbsrefstate = op->d.sbsref.state; + + if (sbsrefstate->upperprovided[0] && sbsrefstate->upperindexnull[0]) + *op->resnull = true; + else if (sbsrefstate->lowerprovided[0] && sbsrefstate->lowerindexnull[0]) + *op->resnull = true; + else + { + Vector *vec = DatumGetVector(*op->resvalue); + int upperindex = sbsrefstate->upperprovided[0] ? DatumGetInt32(sbsrefstate->upperindex[0]) : vec->dim; + int lowerindex = sbsrefstate->lowerprovided[0] ? DatumGetInt32(sbsrefstate->lowerindex[0]) : 1; + int dim; + Vector *result; + + if (upperindex > vec->dim) + upperindex = vec->dim; + + if (lowerindex < 1) + lowerindex = 1; + + dim = upperindex - lowerindex + 1; + + CheckDim(dim); + + result = InitVector(dim); + for (int i = 0; i < dim; i++) + result->x[i] = vec->x[lowerindex + i - 1]; + + *op->resvalue = PointerGetDatum(result); + } +} + +/* + * Set up execution state for a vector subscript operation + */ +static void +vector_exec_setup(const SubscriptingRef *sbsref, SubscriptingRefState *sbsrefstate, SubscriptExecSteps *methods) +{ + methods->sbs_check_subscripts = NULL; + if (sbsrefstate->numlower != 0) + methods->sbs_fetch = vector_subscript_fetch_slice; + else + methods->sbs_fetch = vector_subscript_fetch; + methods->sbs_assign = NULL; + methods->sbs_fetch_old = NULL; +} + +/* + * Subscript handler + */ +PGDLLEXPORT PG_FUNCTION_INFO_V1(vector_subscript); +Datum +vector_subscript(PG_FUNCTION_ARGS) +{ + static const SubscriptRoutines sbsroutines = { + .transform = vector_subscript_transform, + .exec_setup = vector_exec_setup, + .fetch_strict = true, /* fetch returns NULL for NULL inputs */ + .fetch_leakproof = true, /* fetch returns NULL for bad subscript */ + .store_leakproof = false /* ... but assignment throws error */ + }; + + PG_RETURN_POINTER(&sbsroutines); +} + /* * Convert vector to vector * This is needed to check the type modifier diff --git a/test/expected/functions.out b/test/expected/functions.out index 85d1a2f..e065b34 100644 --- a/test/expected/functions.out +++ b/test/expected/functions.out @@ -24,6 +24,112 @@ SELECT '[1e37]'::vector * '[1e37]'; ERROR: value out of range: overflow SELECT '[1e-37]'::vector * '[1e-37]'; ERROR: value out of range: underflow +SELECT ('[1,2,3]'::vector)[0]; + vector +-------- + +(1 row) + +SELECT ('[1,2,3]'::vector)[1]; + vector +-------- + 1 +(1 row) + +SELECT ('[1,2,3]'::vector)[2]; + vector +-------- + 2 +(1 row) + +SELECT ('[1,2,3]'::vector)[3]; + vector +-------- + 3 +(1 row) + +SELECT ('[1,2,3]'::vector)[4]; + vector +-------- + +(1 row) + +SELECT ('[1,2,3]'::vector)[1:1]; + vector +-------- + [1] +(1 row) + +SELECT ('[1,2,3]'::vector)[1:2]; + vector +-------- + [1,2] +(1 row) + +SELECT ('[1,2,3]'::vector)[2:4]; + vector +-------- + [2,3] +(1 row) + +SELECT ('[1,2,3]'::vector)[-2:2]; + vector +-------- + [1,2] +(1 row) + +SELECT ('[1,2,3]'::vector)[2:1]; +ERROR: vector must have at least 1 dimension +SELECT ('[1,2,3]'::vector)[:]; + vector +--------- + [1,2,3] +(1 row) + +SELECT ('[1,2,3]'::vector)[:2]; + vector +-------- + [1,2] +(1 row) + +SELECT ('[1,2,3]'::vector)[2:]; + vector +-------- + [2,3] +(1 row) + +SELECT ('[1,2,3]'::vector)[:4]; + vector +--------- + [1,2,3] +(1 row) + +SELECT ('[1,2,3]'::vector)[-2:]; + vector +--------- + [1,2,3] +(1 row) + +SELECT ('[1,2,3]'::vector)[NULL]; + vector +-------- + +(1 row) + +SELECT ('[1,2,3]'::vector)[NULL:2]; + vector +-------- + +(1 row) + +SELECT ('[1,2,3]'::vector)[2:NULL]; + vector +-------- + +(1 row) + +SELECT ('[1,2,3]'::vector)[1][1]; +ERROR: vector allows only one subscript SELECT '[1,2,3]'::vector = '[1,2,3]'; ?column? ---------- diff --git a/test/sql/functions.sql b/test/sql/functions.sql index 6235684..04bb99e 100644 --- a/test/sql/functions.sql +++ b/test/sql/functions.sql @@ -6,6 +6,26 @@ SELECT '[1,2,3]'::vector * '[4,5,6]'; SELECT '[1e37]'::vector * '[1e37]'; SELECT '[1e-37]'::vector * '[1e-37]'; +SELECT ('[1,2,3]'::vector)[0]; +SELECT ('[1,2,3]'::vector)[1]; +SELECT ('[1,2,3]'::vector)[2]; +SELECT ('[1,2,3]'::vector)[3]; +SELECT ('[1,2,3]'::vector)[4]; +SELECT ('[1,2,3]'::vector)[1:1]; +SELECT ('[1,2,3]'::vector)[1:2]; +SELECT ('[1,2,3]'::vector)[2:4]; +SELECT ('[1,2,3]'::vector)[-2:2]; +SELECT ('[1,2,3]'::vector)[2:1]; +SELECT ('[1,2,3]'::vector)[:]; +SELECT ('[1,2,3]'::vector)[:2]; +SELECT ('[1,2,3]'::vector)[2:]; +SELECT ('[1,2,3]'::vector)[:4]; +SELECT ('[1,2,3]'::vector)[-2:]; +SELECT ('[1,2,3]'::vector)[NULL]; +SELECT ('[1,2,3]'::vector)[NULL:2]; +SELECT ('[1,2,3]'::vector)[2:NULL]; +SELECT ('[1,2,3]'::vector)[1][1]; + SELECT '[1,2,3]'::vector = '[1,2,3]'; SELECT '[1,2,3]'::vector = '[1,2]';