Added subscript function for vectors

This commit is contained in:
Andrew Kane
2024-03-17 19:25:39 -07:00
parent b64a1482d9
commit 71ee682ed4
6 changed files with 322 additions and 0 deletions

View File

@@ -1,3 +1,7 @@
## 0.7.0 (unreleased)
- Added subscript function for vectors
## 0.6.2 (unreleased)
- Reduced lock contention with parallel HNSW index builds

View File

@@ -0,0 +1,7 @@
-- complain if script is sourced in psql, rather than via CREATE EXTENSION
\echo Use "ALTER EXTENSION vector UPDATE TO '0.7.0'" to load this file. \quit
-- Register the C-language subscript handler, then attach it to the already
-- existing vector type so installations upgraded to this version accept
-- subscript syntax on vectors.
CREATE FUNCTION vector_subscript(internal) RETURNS internal
AS 'MODULE_PATHNAME' LANGUAGE C IMMUTABLE STRICT PARALLEL SAFE;
-- NOTE(review): the SUBSCRIPT type property presumably requires PostgreSQL 14
-- or later; confirm against the minimum server version this extension supports.
ALTER TYPE vector SET (SUBSCRIPT = vector_subscript);

View File

@@ -20,12 +20,16 @@ CREATE FUNCTION vector_recv(internal, oid, integer) RETURNS vector
-- Send function for the vector type (wired up as SEND in CREATE TYPE below);
-- takes a vector and returns bytea.
CREATE FUNCTION vector_send(vector) RETURNS bytea
AS 'MODULE_PATHNAME' LANGUAGE C IMMUTABLE STRICT PARALLEL SAFE;
-- Subscript handler for the vector type (wired up as SUBSCRIPT in CREATE TYPE
-- below); implemented in C and exchanged with the server as "internal".
CREATE FUNCTION vector_subscript(internal) RETURNS internal
AS 'MODULE_PATHNAME' LANGUAGE C IMMUTABLE STRICT PARALLEL SAFE;
-- Define the vector type and bind its text I/O, type-modifier input, binary
-- receive/send, and subscript support functions.
-- NOTE(review): the SUBSCRIPT option presumably requires PostgreSQL 14 or
-- later; confirm this install script is not run on older servers.
CREATE TYPE vector (
INPUT = vector_in,
OUTPUT = vector_out,
TYPMOD_IN = vector_typmod_in,
RECEIVE = vector_recv,
SEND = vector_send,
SUBSCRIPT = vector_subscript,
STORAGE = external
);

View File

@@ -4,11 +4,18 @@
#include "catalog/pg_type.h"
#include "common/shortest_dec.h"
#include "executor/execExpr.h"
#include "fmgr.h"
#include "hnsw.h"
#include "ivfflat.h"
#include "lib/stringinfo.h"
#include "libpq/pqformat.h"
#include "nodes/makefuncs.h"
#include "nodes/nodeFuncs.h"
#include "nodes/subscripting.h"
#include "parser/parse_coerce.h"
#include "parser/parse_expr.h"
#include "parser/parse_node.h"
#include "port.h" /* for strtof() */
#include "utils/array.h"
#include "utils/builtins.h"
@@ -418,6 +425,180 @@ vector_send(PG_FUNCTION_ARGS)
PG_RETURN_BYTEA_P(pq_endtypsend(&buf));
}
/*
* Transform the subscript expressions
*/
static void
vector_subscript_transform(SubscriptingRef *sbsref, List *indirection, ParseState *pstate, bool isSlice, bool isAssignment)
{
A_Indices *ai;
Node *subexpr;
if (list_length(indirection) != 1)
ereport(ERROR,
(errcode(ERRCODE_FEATURE_NOT_SUPPORTED),
errmsg("vector allows only one subscript"),
parser_errposition(pstate,
exprLocation((Node *) indirection))));
ai = linitial_node(A_Indices, indirection);
if (isSlice)
{
if (ai->lidx)
{
subexpr = transformExpr(pstate, ai->lidx, pstate->p_expr_kind);
/* If it's not int4 already, try to coerce */
subexpr = coerce_to_target_type(pstate,
subexpr, exprType(subexpr),
INT4OID, -1,
COERCION_ASSIGNMENT,
COERCE_IMPLICIT_CAST,
-1);
if (subexpr == NULL)
ereport(ERROR,
(errcode(ERRCODE_DATATYPE_MISMATCH),
errmsg("vector subscript must have type integer"),
parser_errposition(pstate, exprLocation(ai->lidx))));
}
else if (!ai->is_slice)
{
/* Make a constant 1 */
subexpr = (Node *) makeConst(INT4OID,
-1,
InvalidOid,
sizeof(int32),
Int32GetDatum(1),
false,
true); /* pass by value */
}
else
{
/* Slice with omitted lower bound, put NULL into the list */
subexpr = NULL;
}
sbsref->reflowerindexpr = list_make1(subexpr);
}
else
Assert(ai->lidx == NULL && !ai->is_slice);
if (ai->uidx)
{
subexpr = transformExpr(pstate, ai->uidx, pstate->p_expr_kind);
/* If it's not int4 already, try to coerce */
subexpr = coerce_to_target_type(pstate,
subexpr, exprType(subexpr),
INT4OID, -1,
COERCION_ASSIGNMENT,
COERCE_IMPLICIT_CAST,
-1);
if (subexpr == NULL)
ereport(ERROR,
(errcode(ERRCODE_DATATYPE_MISMATCH),
errmsg("array subscript must have type integer"),
parser_errposition(pstate, exprLocation(ai->uidx))));
}
else
{
/* Slice with omitted upper bound, put NULL into the list */
Assert(isSlice && ai->is_slice);
subexpr = NULL;
}
sbsref->refupperindexpr = list_make1(subexpr);
if (isSlice)
sbsref->refrestype = sbsref->refcontainertype;
else
sbsref->refrestype = FLOAT4OID;
}
/*
 * Fetch a vector element
 *
 * On entry *op->resvalue holds the source vector. The result is the float4
 * at the 1-based subscript, or NULL when the subscript is NULL or lies
 * outside 1..dim.
 */
static void
vector_subscript_fetch(ExprState *state, ExprEvalStep *op, ExprContext *econtext)
{
	SubscriptingRefState *sbsrefstate = op->d.sbsref.state;
	Vector	   *vec;
	int			index;

	/*
	 * No sbs_check_subscripts callback is installed for vectors, so a NULL
	 * subscript has to be detected here. The datum stored alongside a NULL
	 * flag is not meaningful and must not be interpreted as an index.
	 */
	if (sbsrefstate->upperindexnull[0])
	{
		*op->resnull = true;
		return;
	}

	vec = DatumGetVector(*op->resvalue);
	index = DatumGetInt32(sbsrefstate->upperindex[0]);

	/* Subscripts are 1-based; out-of-range yields NULL rather than an error */
	if (index < 1 || index > vec->dim)
		*op->resnull = true;
	else
		*op->resvalue = Float4GetDatum(vec->x[index - 1]);
}
/*
 * Fetch a vector slice
 *
 * On entry *op->resvalue holds the source vector. A provided bound that is
 * NULL makes the result NULL. Omitted bounds default to 1 (lower) and the
 * vector's dimension count (upper); provided bounds are clamped to that
 * range. The result is a new vector holding elements lower..upper
 * (1-based, inclusive). An empty range is rejected by CheckDim.
 */
static void
vector_subscript_fetch_slice(ExprState *state, ExprEvalStep *op, ExprContext *econtext)
{
	SubscriptingRefState *sbsrefstate = op->d.sbsref.state;

	if (sbsrefstate->upperprovided[0] && sbsrefstate->upperindexnull[0])
		*op->resnull = true;
	else if (sbsrefstate->lowerprovided[0] && sbsrefstate->lowerindexnull[0])
		*op->resnull = true;
	else
	{
		Vector	   *vec = DatumGetVector(*op->resvalue);
		int			upperindex = sbsrefstate->upperprovided[0] ? DatumGetInt32(sbsrefstate->upperindex[0]) : vec->dim;
		int			lowerindex = sbsrefstate->lowerprovided[0] ? DatumGetInt32(sbsrefstate->lowerindex[0]) : 1;
		int			dim;
		Vector	   *result;

		if (upperindex > vec->dim)
			upperindex = vec->dim;

		if (lowerindex < 1)
			lowerindex = 1;

		/*
		 * Only subtract once the bounds are known to be ordered. The upper
		 * bound has no floor and the lower bound has no ceiling, so for
		 * extreme inputs (e.g. [2147483647:-2147483648]) the subtraction
		 * would overflow and could produce a positive length, leading to a
		 * read past the end of the source vector. When ordered, both bounds
		 * lie within 1..vec->dim and the arithmetic is safe.
		 */
		if (lowerindex > upperindex)
			dim = 0;
		else
			dim = upperindex - lowerindex + 1;

		/* Rejects an empty slice */
		CheckDim(dim);
		result = InitVector(dim);

		for (int i = 0; i < dim; i++)
			result->x[i] = vec->x[lowerindex + i - 1];

		*op->resvalue = PointerGetDatum(result);
	}
}
/*
 * Set up execution state for a vector subscript operation
 *
 * Selects the fetch callback: the slice variant when the operation carries
 * any lower subscripts, the single-element variant otherwise. No
 * subscript-check, assignment, or fetch-old callbacks are installed.
 */
static void
vector_exec_setup(const SubscriptingRef *sbsref, SubscriptingRefState *sbsrefstate, SubscriptExecSteps *methods)
{
	bool		is_slice = (sbsrefstate->numlower != 0);

	methods->sbs_fetch = is_slice ? vector_subscript_fetch_slice : vector_subscript_fetch;

	/* Callbacks that vectors do not provide */
	methods->sbs_check_subscripts = NULL;
	methods->sbs_fetch_old = NULL;
	methods->sbs_assign = NULL;
}
/*
* Subscript handler
*/
PGDLLEXPORT PG_FUNCTION_INFO_V1(vector_subscript);
Datum
vector_subscript(PG_FUNCTION_ARGS)
{
static const SubscriptRoutines sbsroutines = {
.transform = vector_subscript_transform,
.exec_setup = vector_exec_setup,
.fetch_strict = true, /* fetch returns NULL for NULL inputs */
.fetch_leakproof = true, /* fetch returns NULL for bad subscript */
.store_leakproof = false /* ... but assignment throws error */
};
PG_RETURN_POINTER(&sbsroutines);
}
/*
* Convert vector to vector
* This is needed to check the type modifier

View File

@@ -24,6 +24,112 @@ SELECT '[1e37]'::vector * '[1e37]';
ERROR: value out of range: overflow
SELECT '[1e-37]'::vector * '[1e-37]';
ERROR: value out of range: underflow
SELECT ('[1,2,3]'::vector)[0];
vector
--------
(1 row)
SELECT ('[1,2,3]'::vector)[1];
vector
--------
1
(1 row)
SELECT ('[1,2,3]'::vector)[2];
vector
--------
2
(1 row)
SELECT ('[1,2,3]'::vector)[3];
vector
--------
3
(1 row)
SELECT ('[1,2,3]'::vector)[4];
vector
--------
(1 row)
SELECT ('[1,2,3]'::vector)[1:1];
vector
--------
[1]
(1 row)
SELECT ('[1,2,3]'::vector)[1:2];
vector
--------
[1,2]
(1 row)
SELECT ('[1,2,3]'::vector)[2:4];
vector
--------
[2,3]
(1 row)
SELECT ('[1,2,3]'::vector)[-2:2];
vector
--------
[1,2]
(1 row)
SELECT ('[1,2,3]'::vector)[2:1];
ERROR: vector must have at least 1 dimension
SELECT ('[1,2,3]'::vector)[:];
vector
---------
[1,2,3]
(1 row)
SELECT ('[1,2,3]'::vector)[:2];
vector
--------
[1,2]
(1 row)
SELECT ('[1,2,3]'::vector)[2:];
vector
--------
[2,3]
(1 row)
SELECT ('[1,2,3]'::vector)[:4];
vector
---------
[1,2,3]
(1 row)
SELECT ('[1,2,3]'::vector)[-2:];
vector
---------
[1,2,3]
(1 row)
SELECT ('[1,2,3]'::vector)[NULL];
vector
--------
(1 row)
SELECT ('[1,2,3]'::vector)[NULL:2];
vector
--------
(1 row)
SELECT ('[1,2,3]'::vector)[2:NULL];
vector
--------
(1 row)
SELECT ('[1,2,3]'::vector)[1][1];
ERROR: vector allows only one subscript
SELECT '[1,2,3]'::vector = '[1,2,3]';
?column?
----------

View File

@@ -6,6 +6,26 @@ SELECT '[1,2,3]'::vector * '[4,5,6]';
SELECT '[1e37]'::vector * '[1e37]';
SELECT '[1e-37]'::vector * '[1e-37]';
SELECT ('[1,2,3]'::vector)[0];
SELECT ('[1,2,3]'::vector)[1];
SELECT ('[1,2,3]'::vector)[2];
SELECT ('[1,2,3]'::vector)[3];
SELECT ('[1,2,3]'::vector)[4];
SELECT ('[1,2,3]'::vector)[1:1];
SELECT ('[1,2,3]'::vector)[1:2];
SELECT ('[1,2,3]'::vector)[2:4];
SELECT ('[1,2,3]'::vector)[-2:2];
SELECT ('[1,2,3]'::vector)[2:1];
SELECT ('[1,2,3]'::vector)[:];
SELECT ('[1,2,3]'::vector)[:2];
SELECT ('[1,2,3]'::vector)[2:];
SELECT ('[1,2,3]'::vector)[:4];
SELECT ('[1,2,3]'::vector)[-2:];
SELECT ('[1,2,3]'::vector)[NULL];
SELECT ('[1,2,3]'::vector)[NULL:2];
SELECT ('[1,2,3]'::vector)[2:NULL];
SELECT ('[1,2,3]'::vector)[1][1];
SELECT '[1,2,3]'::vector = '[1,2,3]';
SELECT '[1,2,3]'::vector = '[1,2]';