From 0da6213a60bc71aeccc24d1c4c28cb744beb6fd2 Mon Sep 17 00:00:00 2001 From: Andrew Kane Date: Tue, 23 Apr 2024 13:02:47 -0700 Subject: [PATCH] Moved type lookup to support functions - #527 --- sql/vector--0.6.2--0.7.0.sql | 57 +++++++++++++++++++++++++--------- sql/vector.sql | 59 +++++++++++++++++++++++++++--------- src/hnsw.c | 2 +- src/hnsw.h | 4 ++- src/hnswutils.c | 52 ++++++++++++++++++------------- src/ivfflat.c | 2 +- src/ivfflat.h | 4 ++- src/ivfutils.c | 43 ++++++++++++++------------ 8 files changed, 151 insertions(+), 72 deletions(-) diff --git a/sql/vector--0.6.2--0.7.0.sql b/sql/vector--0.6.2--0.7.0.sql index 8e135e6..846d5a5 100644 --- a/sql/vector--0.6.2--0.7.0.sql +++ b/sql/vector--0.6.2--0.7.0.sql @@ -33,6 +33,12 @@ CREATE FUNCTION hamming_distance(bit, bit) RETURNS float8 CREATE FUNCTION jaccard_distance(bit, bit) RETURNS float8 AS 'MODULE_PATHNAME' LANGUAGE C IMMUTABLE STRICT PARALLEL SAFE; +CREATE FUNCTION bit_ivfflat_support(internal) RETURNS internal + AS 'MODULE_PATHNAME' LANGUAGE C IMMUTABLE STRICT PARALLEL SAFE; + +CREATE FUNCTION bit_hnsw_support(internal) RETURNS internal + AS 'MODULE_PATHNAME' LANGUAGE C IMMUTABLE STRICT PARALLEL SAFE; + CREATE OPERATOR <~> ( LEFTARG = bit, RIGHTARG = bit, PROCEDURE = hamming_distance, COMMUTATOR = '<~>' @@ -47,17 +53,20 @@ CREATE OPERATOR CLASS bit_hamming_ops FOR TYPE bit USING ivfflat AS OPERATOR 1 <~> (bit, bit) FOR ORDER BY float_ops, FUNCTION 1 hamming_distance(bit, bit), - FUNCTION 3 hamming_distance(bit, bit); + FUNCTION 3 hamming_distance(bit, bit), + FUNCTION 6 bit_ivfflat_support(internal); CREATE OPERATOR CLASS bit_hamming_ops FOR TYPE bit USING hnsw AS OPERATOR 1 <~> (bit, bit) FOR ORDER BY float_ops, - FUNCTION 1 hamming_distance(bit, bit); + FUNCTION 1 hamming_distance(bit, bit), + FUNCTION 4 bit_hnsw_support(internal); CREATE OPERATOR CLASS bit_jaccard_ops FOR TYPE bit USING hnsw AS OPERATOR 1 <%> (bit, bit) FOR ORDER BY float_ops, - FUNCTION 1 jaccard_distance(bit, bit); + FUNCTION 1 jaccard_distance(bit, bit), + FUNCTION 4 bit_hnsw_support(internal); CREATE TYPE halfvec; @@ -160,6 +169,12 @@ CREATE FUNCTION halfvec_accum(double precision[], halfvec) RETURNS double precis CREATE FUNCTION halfvec_avg(double precision[]) RETURNS halfvec AS 'MODULE_PATHNAME' LANGUAGE C IMMUTABLE STRICT PARALLEL SAFE; +CREATE FUNCTION halfvec_ivfflat_support(internal) RETURNS internal + AS 'MODULE_PATHNAME' LANGUAGE C IMMUTABLE STRICT PARALLEL SAFE; + +CREATE FUNCTION halfvec_hnsw_support(internal) RETURNS internal + AS 'MODULE_PATHNAME' LANGUAGE C IMMUTABLE STRICT PARALLEL SAFE; + CREATE AGGREGATE avg(halfvec) ( SFUNC = halfvec_accum, STYPE = double precision[], @@ -311,7 +326,8 @@ CREATE OPERATOR CLASS halfvec_l2_ops FOR TYPE halfvec USING ivfflat AS OPERATOR 1 <-> (halfvec, halfvec) FOR ORDER BY float_ops, FUNCTION 1 halfvec_l2_squared_distance(halfvec, halfvec), - FUNCTION 3 l2_distance(halfvec, halfvec); + FUNCTION 3 l2_distance(halfvec, halfvec), + FUNCTION 6 halfvec_ivfflat_support(internal); CREATE OPERATOR CLASS halfvec_ip_ops FOR TYPE halfvec USING ivfflat AS @@ -319,7 +335,8 @@ CREATE OPERATOR CLASS halfvec_ip_ops FUNCTION 1 halfvec_negative_inner_product(halfvec, halfvec), FUNCTION 3 halfvec_spherical_distance(halfvec, halfvec), FUNCTION 4 l2_norm(halfvec), - FUNCTION 5 l2_normalize(halfvec); + FUNCTION 5 l2_normalize(halfvec), + FUNCTION 6 halfvec_ivfflat_support(internal); CREATE OPERATOR CLASS halfvec_cosine_ops FOR TYPE halfvec USING ivfflat AS @@ -328,29 +345,34 @@ CREATE OPERATOR CLASS halfvec_cosine_ops FUNCTION 2 l2_norm(halfvec), FUNCTION 3 halfvec_spherical_distance(halfvec, halfvec), FUNCTION 4 l2_norm(halfvec), - FUNCTION 5 l2_normalize(halfvec); + FUNCTION 5 l2_normalize(halfvec), + FUNCTION 6 halfvec_ivfflat_support(internal); CREATE OPERATOR CLASS halfvec_l2_ops FOR TYPE halfvec USING hnsw AS OPERATOR 1 <-> (halfvec, halfvec) FOR ORDER BY float_ops, - FUNCTION 1 halfvec_l2_squared_distance(halfvec, halfvec); + FUNCTION 1 halfvec_l2_squared_distance(halfvec, halfvec), + FUNCTION 4 halfvec_hnsw_support(internal); CREATE OPERATOR CLASS halfvec_ip_ops FOR TYPE halfvec USING hnsw AS OPERATOR 1 <#> (halfvec, halfvec) FOR ORDER BY float_ops, - FUNCTION 1 halfvec_negative_inner_product(halfvec, halfvec); + FUNCTION 1 halfvec_negative_inner_product(halfvec, halfvec), + FUNCTION 4 halfvec_hnsw_support(internal); CREATE OPERATOR CLASS halfvec_cosine_ops FOR TYPE halfvec USING hnsw AS OPERATOR 1 <=> (halfvec, halfvec) FOR ORDER BY float_ops, FUNCTION 1 halfvec_negative_inner_product(halfvec, halfvec), FUNCTION 2 l2_norm(halfvec), - FUNCTION 3 l2_normalize(halfvec); + FUNCTION 3 l2_normalize(halfvec), + FUNCTION 4 halfvec_hnsw_support(internal); CREATE OPERATOR CLASS halfvec_l1_ops FOR TYPE halfvec USING hnsw AS OPERATOR 1 <+> (halfvec, halfvec) FOR ORDER BY float_ops, - FUNCTION 1 l1_distance(halfvec, halfvec); + FUNCTION 1 l1_distance(halfvec, halfvec), + FUNCTION 4 halfvec_hnsw_support(internal); CREATE TYPE sparsevec; @@ -438,6 +460,9 @@ CREATE FUNCTION halfvec_to_sparsevec(halfvec, integer, boolean) RETURNS sparseve CREATE FUNCTION sparsevec_to_halfvec(sparsevec, integer, boolean) RETURNS halfvec AS 'MODULE_PATHNAME' LANGUAGE C IMMUTABLE STRICT PARALLEL SAFE; +CREATE FUNCTION sparsevec_hnsw_support(internal) RETURNS internal + AS 'MODULE_PATHNAME' LANGUAGE C IMMUTABLE STRICT PARALLEL SAFE; + CREATE CAST (sparsevec AS sparsevec) WITH FUNCTION sparsevec(sparsevec, integer, boolean) AS IMPLICIT; @@ -521,21 +546,25 @@ CREATE OPERATOR CLASS sparsevec_ops CREATE OPERATOR CLASS sparsevec_l2_ops FOR TYPE sparsevec USING hnsw AS OPERATOR 1 <-> (sparsevec, sparsevec) FOR ORDER BY float_ops, - FUNCTION 1 sparsevec_l2_squared_distance(sparsevec, sparsevec); + FUNCTION 1 sparsevec_l2_squared_distance(sparsevec, sparsevec), + FUNCTION 4 sparsevec_hnsw_support(internal); CREATE OPERATOR CLASS sparsevec_ip_ops FOR TYPE sparsevec USING hnsw AS OPERATOR 1 <#> (sparsevec, sparsevec) FOR ORDER BY float_ops, - FUNCTION 1 sparsevec_negative_inner_product(sparsevec, sparsevec); + FUNCTION 1 sparsevec_negative_inner_product(sparsevec, sparsevec), + FUNCTION 4 sparsevec_hnsw_support(internal); CREATE OPERATOR CLASS sparsevec_cosine_ops FOR TYPE sparsevec USING hnsw AS OPERATOR 1 <=> (sparsevec, sparsevec) FOR ORDER BY float_ops, FUNCTION 1 sparsevec_negative_inner_product(sparsevec, sparsevec), FUNCTION 2 l2_norm(sparsevec), - FUNCTION 3 l2_normalize(sparsevec); + FUNCTION 3 l2_normalize(sparsevec), + FUNCTION 4 sparsevec_hnsw_support(internal); CREATE OPERATOR CLASS sparsevec_l1_ops FOR TYPE sparsevec USING hnsw AS OPERATOR 1 <+> (sparsevec, sparsevec) FOR ORDER BY float_ops, - FUNCTION 1 l1_distance(sparsevec, sparsevec); + FUNCTION 1 l1_distance(sparsevec, sparsevec), + FUNCTION 4 sparsevec_hnsw_support(internal); diff --git a/sql/vector.sql b/sql/vector.sql index a43871d..da09d1c 100644 --- a/sql/vector.sql +++ b/sql/vector.sql @@ -322,6 +322,14 @@ CREATE FUNCTION hamming_distance(bit, bit) RETURNS float8 CREATE FUNCTION jaccard_distance(bit, bit) RETURNS float8 AS 'MODULE_PATHNAME' LANGUAGE C IMMUTABLE STRICT PARALLEL SAFE; +-- bit private functions + +CREATE FUNCTION bit_ivfflat_support(internal) RETURNS internal + AS 'MODULE_PATHNAME' LANGUAGE C IMMUTABLE STRICT PARALLEL SAFE; + +CREATE FUNCTION bit_hnsw_support(internal) RETURNS internal + AS 'MODULE_PATHNAME' LANGUAGE C IMMUTABLE STRICT PARALLEL SAFE; + -- bit operators CREATE OPERATOR <~> ( @@ -340,17 +348,20 @@ CREATE OPERATOR CLASS bit_hamming_ops FOR TYPE bit USING ivfflat AS OPERATOR 1 <~> (bit, bit) FOR ORDER BY float_ops, FUNCTION 1 hamming_distance(bit, bit), - FUNCTION 3 hamming_distance(bit, bit); + FUNCTION 3 hamming_distance(bit, bit), + FUNCTION 6 bit_ivfflat_support(internal); CREATE OPERATOR CLASS bit_hamming_ops FOR TYPE bit USING hnsw AS OPERATOR 1 <~> (bit, bit) FOR ORDER BY float_ops, - FUNCTION 1 hamming_distance(bit, bit); + FUNCTION 1 hamming_distance(bit, bit), + FUNCTION 4 bit_hnsw_support(internal); CREATE OPERATOR CLASS bit_jaccard_ops FOR TYPE bit USING hnsw AS OPERATOR 1 <%> (bit, bit) FOR ORDER BY float_ops, - FUNCTION 1 jaccard_distance(bit, bit); + FUNCTION 1 jaccard_distance(bit, bit), + FUNCTION 4 bit_hnsw_support(internal); -- halfvec type @@ -459,6 +470,12 @@ CREATE FUNCTION halfvec_accum(double precision[], halfvec) RETURNS double precis CREATE FUNCTION halfvec_avg(double precision[]) RETURNS halfvec AS 'MODULE_PATHNAME' LANGUAGE C IMMUTABLE STRICT PARALLEL SAFE; +CREATE FUNCTION halfvec_ivfflat_support(internal) RETURNS internal + AS 'MODULE_PATHNAME' LANGUAGE C IMMUTABLE STRICT PARALLEL SAFE; + +CREATE FUNCTION halfvec_hnsw_support(internal) RETURNS internal + AS 'MODULE_PATHNAME' LANGUAGE C IMMUTABLE STRICT PARALLEL SAFE; + -- halfvec aggregates CREATE AGGREGATE avg(halfvec) ( @@ -620,7 +637,8 @@ CREATE OPERATOR CLASS halfvec_l2_ops FOR TYPE halfvec USING ivfflat AS OPERATOR 1 <-> (halfvec, halfvec) FOR ORDER BY float_ops, FUNCTION 1 halfvec_l2_squared_distance(halfvec, halfvec), - FUNCTION 3 l2_distance(halfvec, halfvec); + FUNCTION 3 l2_distance(halfvec, halfvec), + FUNCTION 6 halfvec_ivfflat_support(internal); CREATE OPERATOR CLASS halfvec_ip_ops FOR TYPE halfvec USING ivfflat AS @@ -628,7 +646,8 @@ CREATE OPERATOR CLASS halfvec_ip_ops FUNCTION 1 halfvec_negative_inner_product(halfvec, halfvec), FUNCTION 3 halfvec_spherical_distance(halfvec, halfvec), FUNCTION 4 l2_norm(halfvec), - FUNCTION 5 l2_normalize(halfvec); + FUNCTION 5 l2_normalize(halfvec), + FUNCTION 6 halfvec_ivfflat_support(internal); CREATE OPERATOR CLASS halfvec_cosine_ops FOR TYPE halfvec USING ivfflat AS @@ -637,29 +656,34 @@ CREATE OPERATOR CLASS halfvec_cosine_ops FUNCTION 2 l2_norm(halfvec), FUNCTION 3 halfvec_spherical_distance(halfvec, halfvec), FUNCTION 4 l2_norm(halfvec), - FUNCTION 5 l2_normalize(halfvec); + FUNCTION 5 l2_normalize(halfvec), + FUNCTION 6 halfvec_ivfflat_support(internal); CREATE OPERATOR CLASS halfvec_l2_ops FOR TYPE halfvec USING hnsw AS OPERATOR 1 <-> (halfvec, halfvec) FOR ORDER BY float_ops, - FUNCTION 1 halfvec_l2_squared_distance(halfvec, halfvec); + FUNCTION 1 halfvec_l2_squared_distance(halfvec, halfvec), + FUNCTION 4 halfvec_hnsw_support(internal); CREATE OPERATOR CLASS halfvec_ip_ops FOR TYPE halfvec USING hnsw AS OPERATOR 1 <#> (halfvec, halfvec) FOR ORDER BY float_ops, - FUNCTION 1 halfvec_negative_inner_product(halfvec, halfvec); + FUNCTION 1 halfvec_negative_inner_product(halfvec, halfvec), + FUNCTION 4 halfvec_hnsw_support(internal); CREATE OPERATOR CLASS halfvec_cosine_ops FOR TYPE halfvec USING hnsw AS OPERATOR 1 <=> (halfvec, halfvec) FOR ORDER BY float_ops, FUNCTION 1 halfvec_negative_inner_product(halfvec, halfvec), FUNCTION 2 l2_norm(halfvec), - FUNCTION 3 l2_normalize(halfvec); + FUNCTION 3 l2_normalize(halfvec), + FUNCTION 4 halfvec_hnsw_support(internal); CREATE OPERATOR CLASS halfvec_l1_ops FOR TYPE halfvec USING hnsw AS OPERATOR 1 <+> (halfvec, halfvec) FOR ORDER BY float_ops, - FUNCTION 1 l1_distance(halfvec, halfvec); + FUNCTION 1 l1_distance(halfvec, halfvec), + FUNCTION 4 halfvec_hnsw_support(internal); --- sparsevec type @@ -755,6 +779,9 @@ CREATE FUNCTION halfvec_to_sparsevec(halfvec, integer, boolean) RETURNS sparseve CREATE FUNCTION sparsevec_to_halfvec(sparsevec, integer, boolean) RETURNS halfvec AS 'MODULE_PATHNAME' LANGUAGE C IMMUTABLE STRICT PARALLEL SAFE; +CREATE FUNCTION sparsevec_hnsw_support(internal) RETURNS internal + AS 'MODULE_PATHNAME' LANGUAGE C IMMUTABLE STRICT PARALLEL SAFE; + -- sparsevec casts CREATE CAST (sparsevec AS sparsevec) @@ -844,21 +871,25 @@ CREATE OPERATOR CLASS sparsevec_ops CREATE OPERATOR CLASS sparsevec_l2_ops FOR TYPE sparsevec USING hnsw AS OPERATOR 1 <-> (sparsevec, sparsevec) FOR ORDER BY float_ops, - FUNCTION 1 sparsevec_l2_squared_distance(sparsevec, sparsevec); + FUNCTION 1 sparsevec_l2_squared_distance(sparsevec, sparsevec), + FUNCTION 4 sparsevec_hnsw_support(internal); CREATE OPERATOR CLASS sparsevec_ip_ops FOR TYPE sparsevec USING hnsw AS OPERATOR 1 <#> (sparsevec, sparsevec) FOR ORDER BY float_ops, - FUNCTION 1 sparsevec_negative_inner_product(sparsevec, sparsevec); + FUNCTION 1 sparsevec_negative_inner_product(sparsevec, sparsevec), + FUNCTION 4 sparsevec_hnsw_support(internal); CREATE OPERATOR CLASS sparsevec_cosine_ops FOR TYPE sparsevec USING hnsw AS OPERATOR 1 <=> (sparsevec, sparsevec) FOR ORDER BY float_ops, FUNCTION 1 sparsevec_negative_inner_product(sparsevec, sparsevec), FUNCTION 2 l2_norm(sparsevec), - FUNCTION 3 l2_normalize(sparsevec); + FUNCTION 3 l2_normalize(sparsevec), + FUNCTION 4 sparsevec_hnsw_support(internal); CREATE OPERATOR CLASS sparsevec_l1_ops FOR TYPE sparsevec USING hnsw AS OPERATOR 1 <+> (sparsevec, sparsevec) FOR ORDER BY float_ops, - FUNCTION 1 l1_distance(sparsevec, sparsevec); + FUNCTION 1 l1_distance(sparsevec, sparsevec), + FUNCTION 4 sparsevec_hnsw_support(internal); diff --git a/src/hnsw.c b/src/hnsw.c index b56ab71..9f32260 100644 --- a/src/hnsw.c +++ b/src/hnsw.c @@ -194,7 +194,7 @@ hnswhandler(PG_FUNCTION_ARGS) IndexAmRoutine *amroutine = makeNode(IndexAmRoutine); amroutine->amstrategies = 0; - amroutine->amsupport = 3; + amroutine->amsupport = 4; #if PG_VERSION_NUM >= 130000 amroutine->amoptsprocnum = 0; #endif diff --git a/src/hnsw.h b/src/hnsw.h index d02522b..6072962 100644 --- a/src/hnsw.h +++ b/src/hnsw.h @@ -23,6 +23,7 @@ #define HNSW_DISTANCE_PROC 1 #define HNSW_NORM_PROC 2 #define HNSW_NORMALIZE_PROC 3 +#define HNSW_TYPE_SUPPORT_PROC 4 #define HNSW_VERSION 1 #define HNSW_MAGIC_NUMBER 0xA953A953 @@ -62,7 +63,8 @@ typedef enum HnswType HNSW_TYPE_VECTOR, HNSW_TYPE_HALFVEC, HNSW_TYPE_BIT, - HNSW_TYPE_SPARSEVEC + HNSW_TYPE_SPARSEVEC, + HNSW_TYPE_UNSUPPORTED } HnswType; /* Build phases */ diff --git a/src/hnswutils.c b/src/hnswutils.c index 3773b1d..6fc6dcc 100644 --- a/src/hnswutils.c +++ b/src/hnswutils.c @@ -13,7 +13,6 @@ #include "utils/datum.h" #include "utils/memdebug.h" #include "utils/rel.h" -#include "utils/syscache.h" #if PG_VERSION_NUM >= 130000 #include "common/hashfn.h" @@ -159,32 +158,17 @@ HnswOptionalProcInfo(Relation index, uint16 procnum) HnswType HnswGetType(Relation index) { + FmgrInfo *procinfo = HnswOptionalProcInfo(index, HNSW_TYPE_SUPPORT_PROC); Oid typid = TupleDescAttr(index->rd_att, 0)->atttypid; - HeapTuple tuple; - Form_pg_type type; HnswType result; - if (typid == BITOID) - return HNSW_TYPE_BIT; + if (procinfo == NULL) + return HNSW_TYPE_VECTOR; - tuple = SearchSysCache1(TYPEOID, ObjectIdGetDatum(typid)); - if (!HeapTupleIsValid(tuple)) - elog(ERROR, "cache lookup failed for type %u", typid); + result = (HnswType) DatumGetInt32(FunctionCall1(procinfo, ObjectIdGetDatum(typid))); - type = (Form_pg_type) GETSTRUCT(tuple); - if (strcmp(NameStr(type->typname), "vector") == 0) - result = HNSW_TYPE_VECTOR; - else if (strcmp(NameStr(type->typname), "halfvec") == 0) - result = HNSW_TYPE_HALFVEC; - else if (strcmp(NameStr(type->typname), "sparsevec") == 0) - result = HNSW_TYPE_SPARSEVEC; - else - { - ReleaseSysCache(tuple); + if (result == HNSW_TYPE_UNSUPPORTED) elog(ERROR, "type not supported for hnsw index"); - } - - ReleaseSysCache(tuple); return result; } @@ -1318,3 +1302,29 @@ HnswFindElementNeighbors(char *base, HnswElement element, HnswElement entryPoint ep = w; } } + +PGDLLEXPORT PG_FUNCTION_INFO_V1(halfvec_hnsw_support); +Datum +halfvec_hnsw_support(PG_FUNCTION_ARGS) +{ + PG_RETURN_INT32(HNSW_TYPE_HALFVEC); +}; + +PGDLLEXPORT PG_FUNCTION_INFO_V1(bit_hnsw_support); +Datum +bit_hnsw_support(PG_FUNCTION_ARGS) +{ + Oid typid = PG_GETARG_OID(0); + + if (typid == BITOID) + PG_RETURN_INT32(HNSW_TYPE_BIT); + else + PG_RETURN_INT32(HNSW_TYPE_UNSUPPORTED); +}; + +PGDLLEXPORT PG_FUNCTION_INFO_V1(sparsevec_hnsw_support); +Datum +sparsevec_hnsw_support(PG_FUNCTION_ARGS) +{ + PG_RETURN_INT32(HNSW_TYPE_SPARSEVEC); +}; diff --git a/src/ivfflat.c b/src/ivfflat.c index 53dc766..6bb2422 100644 --- a/src/ivfflat.c +++ b/src/ivfflat.c @@ -188,7 +188,7 @@ ivfflathandler(PG_FUNCTION_ARGS) IndexAmRoutine *amroutine = makeNode(IndexAmRoutine); amroutine->amstrategies = 0; - amroutine->amsupport = 5; + amroutine->amsupport = 6; #if PG_VERSION_NUM >= 130000 amroutine->amoptsprocnum = 0; #endif diff --git a/src/ivfflat.h b/src/ivfflat.h index 1fb873a..3b918cf 100644 --- a/src/ivfflat.h +++ b/src/ivfflat.h @@ -29,6 +29,7 @@ #define IVFFLAT_KMEANS_DISTANCE_PROC 3 #define IVFFLAT_KMEANS_NORM_PROC 4 #define IVFFLAT_NORMALIZE_PROC 5 +#define IVFFLAT_TYPE_SUPPORT_PROC 6 #define IVFFLAT_VERSION 1 #define IVFFLAT_MAGIC_NUMBER 0x14FF1A7 @@ -48,7 +49,8 @@ typedef enum IvfflatType { IVFFLAT_TYPE_VECTOR, IVFFLAT_TYPE_HALFVEC, - IVFFLAT_TYPE_BIT + IVFFLAT_TYPE_BIT, + IVFFLAT_TYPE_UNSUPPORTED } IvfflatType; /* Build phases */ diff --git a/src/ivfutils.c b/src/ivfutils.c index 904ee39..7223ba3 100644 --- a/src/ivfutils.c +++ b/src/ivfutils.c @@ -5,7 +5,6 @@ #include "fmgr.h" #include "ivfflat.h" #include "storage/bufmgr.h" -#include "utils/syscache.h" /* * Allocate a vector array @@ -68,30 +67,17 @@ IvfflatOptionalProcInfo(Relation index, uint16 procnum) IvfflatType IvfflatGetType(Relation index) { + FmgrInfo *procinfo = IvfflatOptionalProcInfo(index, IVFFLAT_TYPE_SUPPORT_PROC); Oid typid = TupleDescAttr(index->rd_att, 0)->atttypid; - HeapTuple tuple; - Form_pg_type type; IvfflatType result; - if (typid == BITOID) - return IVFFLAT_TYPE_BIT; + if (procinfo == NULL) + return IVFFLAT_TYPE_VECTOR; - tuple = SearchSysCache1(TYPEOID, ObjectIdGetDatum(typid)); - if (!HeapTupleIsValid(tuple)) - elog(ERROR, "cache lookup failed for type %u", typid); + result = (IvfflatType) DatumGetInt32(FunctionCall1(procinfo, ObjectIdGetDatum(typid))); - type = (Form_pg_type) GETSTRUCT(tuple); - if (strcmp(NameStr(type->typname), "vector") == 0) - result = IVFFLAT_TYPE_VECTOR; - else if (strcmp(NameStr(type->typname), "halfvec") == 0) - result = IVFFLAT_TYPE_HALFVEC; - else - { - ReleaseSysCache(tuple); + if (result == IVFFLAT_TYPE_UNSUPPORTED) elog(ERROR, "type not supported for ivfflat index"); - } - - ReleaseSysCache(tuple); return result; } @@ -259,3 +245,22 @@ IvfflatUpdateList(Relation index, ListInfo listInfo, UnlockReleaseBuffer(buf); } } + +PGDLLEXPORT PG_FUNCTION_INFO_V1(halfvec_ivfflat_support); +Datum +halfvec_ivfflat_support(PG_FUNCTION_ARGS) +{ + PG_RETURN_INT32(IVFFLAT_TYPE_HALFVEC); +}; + +PGDLLEXPORT PG_FUNCTION_INFO_V1(bit_ivfflat_support); +Datum +bit_ivfflat_support(PG_FUNCTION_ARGS) +{ + Oid typid = PG_GETARG_OID(0); + + if (typid == BITOID) + PG_RETURN_INT32(IVFFLAT_TYPE_BIT); + else + PG_RETURN_INT32(IVFFLAT_TYPE_UNSUPPORTED); +};