diff --git a/sql/vector--0.6.2--0.7.0.sql b/sql/vector--0.6.2--0.7.0.sql index 84c7f3a..2584dd1 100644 --- a/sql/vector--0.6.2--0.7.0.sql +++ b/sql/vector--0.6.2--0.7.0.sql @@ -22,28 +22,10 @@ CREATE OPERATOR || ( LEFTARG = vector, RIGHTARG = vector, PROCEDURE = vector_concat ); -CREATE FUNCTION ivfflat_bit_max_dims(internal) RETURNS internal +CREATE FUNCTION ivfflat_bit_support(internal) RETURNS internal AS 'MODULE_PATHNAME' LANGUAGE C; -CREATE FUNCTION ivfflat_halfvec_max_dims(internal) RETURNS internal - AS 'MODULE_PATHNAME' LANGUAGE C; - -CREATE FUNCTION ivfflat_vector_update_center(internal, internal, internal) RETURNS internal - AS 'MODULE_PATHNAME' LANGUAGE C; - -CREATE FUNCTION ivfflat_bit_update_center(internal, internal, internal) RETURNS internal - AS 'MODULE_PATHNAME' LANGUAGE C; - -CREATE FUNCTION ivfflat_halfvec_update_center(internal, internal, internal) RETURNS internal - AS 'MODULE_PATHNAME' LANGUAGE C; - -CREATE FUNCTION ivfflat_vector_sum_center(internal, internal) RETURNS internal - AS 'MODULE_PATHNAME' LANGUAGE C; - -CREATE FUNCTION ivfflat_bit_sum_center(internal, internal) RETURNS internal - AS 'MODULE_PATHNAME' LANGUAGE C; - -CREATE FUNCTION ivfflat_halfvec_sum_center(internal, internal) RETURNS internal +CREATE FUNCTION ivfflat_halfvec_support(internal) RETURNS internal AS 'MODULE_PATHNAME' LANGUAGE C; CREATE FUNCTION hnsw_bit_max_dims(internal) RETURNS internal @@ -84,9 +66,7 @@ CREATE OPERATOR CLASS bit_hamming_ops OPERATOR 1 <~> (bit, bit) FOR ORDER BY float_ops, FUNCTION 1 hamming_distance(bit, bit), FUNCTION 3 hamming_distance(bit, bit), - FUNCTION 6 ivfflat_bit_max_dims(internal), - FUNCTION 7 ivfflat_bit_update_center(internal, internal, internal), - FUNCTION 8 ivfflat_bit_sum_center(internal, internal); + FUNCTION 6 ivfflat_bit_support(internal); CREATE OPERATOR CLASS bit_hamming_ops FOR TYPE bit USING hnsw AS @@ -353,9 +333,7 @@ CREATE OPERATOR CLASS halfvec_l2_ops OPERATOR 1 <-> (halfvec, halfvec) FOR ORDER BY float_ops, FUNCTION 1 halfvec_l2_squared_distance(halfvec, halfvec), FUNCTION 3 l2_distance(halfvec, halfvec), - FUNCTION 6 ivfflat_halfvec_max_dims(internal), - FUNCTION 7 ivfflat_halfvec_update_center(internal, internal, internal), - FUNCTION 8 ivfflat_halfvec_sum_center(internal, internal); + FUNCTION 6 ivfflat_halfvec_support(internal); CREATE OPERATOR CLASS halfvec_ip_ops FOR TYPE halfvec USING ivfflat AS @@ -364,9 +342,7 @@ CREATE OPERATOR CLASS halfvec_ip_ops FUNCTION 3 halfvec_spherical_distance(halfvec, halfvec), FUNCTION 4 l2_norm(halfvec), FUNCTION 5 l2_normalize(halfvec), - FUNCTION 6 ivfflat_halfvec_max_dims(internal), - FUNCTION 7 ivfflat_halfvec_update_center(internal, internal, internal), - FUNCTION 8 ivfflat_halfvec_sum_center(internal, internal); + FUNCTION 6 ivfflat_halfvec_support(internal); CREATE OPERATOR CLASS halfvec_cosine_ops FOR TYPE halfvec USING ivfflat AS @@ -376,9 +352,7 @@ CREATE OPERATOR CLASS halfvec_cosine_ops FUNCTION 3 halfvec_spherical_distance(halfvec, halfvec), FUNCTION 4 l2_norm(halfvec), FUNCTION 5 l2_normalize(halfvec), - FUNCTION 6 ivfflat_halfvec_max_dims(internal), - FUNCTION 7 ivfflat_halfvec_update_center(internal, internal, internal), - FUNCTION 8 ivfflat_halfvec_sum_center(internal, internal); + FUNCTION 6 ivfflat_halfvec_support(internal); CREATE OPERATOR CLASS halfvec_l2_ops FOR TYPE halfvec USING hnsw AS diff --git a/sql/vector.sql b/sql/vector.sql index 9113b3f..7d1767a 100644 --- a/sql/vector.sql +++ b/sql/vector.sql @@ -263,28 +263,10 @@ COMMENT ON ACCESS METHOD hnsw IS 'hnsw index access method'; -- access method private functions -CREATE FUNCTION ivfflat_bit_max_dims(internal) RETURNS internal +CREATE FUNCTION ivfflat_bit_support(internal) RETURNS internal AS 'MODULE_PATHNAME' LANGUAGE C; -CREATE FUNCTION ivfflat_halfvec_max_dims(internal) RETURNS internal - AS 'MODULE_PATHNAME' LANGUAGE C; - -CREATE FUNCTION ivfflat_vector_update_center(internal, internal, internal) RETURNS internal - AS 'MODULE_PATHNAME' LANGUAGE C; - -CREATE FUNCTION ivfflat_bit_update_center(internal, internal, internal) RETURNS internal - AS 'MODULE_PATHNAME' LANGUAGE C; - -CREATE FUNCTION ivfflat_halfvec_update_center(internal, internal, internal) RETURNS internal - AS 'MODULE_PATHNAME' LANGUAGE C; - -CREATE FUNCTION ivfflat_vector_sum_center(internal, internal) RETURNS internal - AS 'MODULE_PATHNAME' LANGUAGE C; - -CREATE FUNCTION ivfflat_bit_sum_center(internal, internal) RETURNS internal - AS 'MODULE_PATHNAME' LANGUAGE C; - -CREATE FUNCTION ivfflat_halfvec_sum_center(internal, internal) RETURNS internal +CREATE FUNCTION ivfflat_halfvec_support(internal) RETURNS internal AS 'MODULE_PATHNAME' LANGUAGE C; CREATE FUNCTION hnsw_bit_max_dims(internal) RETURNS internal @@ -379,9 +361,7 @@ CREATE OPERATOR CLASS bit_hamming_ops OPERATOR 1 <~> (bit, bit) FOR ORDER BY float_ops, FUNCTION 1 hamming_distance(bit, bit), FUNCTION 3 hamming_distance(bit, bit), - FUNCTION 6 ivfflat_bit_max_dims(internal), - FUNCTION 7 ivfflat_bit_update_center(internal, internal, internal), - FUNCTION 8 ivfflat_bit_sum_center(internal, internal); + FUNCTION 6 ivfflat_bit_support(internal); CREATE OPERATOR CLASS bit_hamming_ops FOR TYPE bit USING hnsw AS @@ -664,9 +644,7 @@ CREATE OPERATOR CLASS halfvec_l2_ops OPERATOR 1 <-> (halfvec, halfvec) FOR ORDER BY float_ops, FUNCTION 1 halfvec_l2_squared_distance(halfvec, halfvec), FUNCTION 3 l2_distance(halfvec, halfvec), - FUNCTION 6 ivfflat_halfvec_max_dims(internal), - FUNCTION 7 ivfflat_halfvec_update_center(internal, internal, internal), - FUNCTION 8 ivfflat_halfvec_sum_center(internal, internal); + FUNCTION 6 ivfflat_halfvec_support(internal); CREATE OPERATOR CLASS halfvec_ip_ops FOR TYPE halfvec USING ivfflat AS @@ -675,9 +653,7 @@ CREATE OPERATOR CLASS halfvec_ip_ops FUNCTION 3 halfvec_spherical_distance(halfvec, halfvec), FUNCTION 4 l2_norm(halfvec), FUNCTION 5 l2_normalize(halfvec), - FUNCTION 6 ivfflat_halfvec_max_dims(internal), - FUNCTION 7 ivfflat_halfvec_update_center(internal, internal, internal), - FUNCTION 8 ivfflat_halfvec_sum_center(internal, internal); + FUNCTION 6 ivfflat_halfvec_support(internal); CREATE OPERATOR CLASS halfvec_cosine_ops FOR TYPE halfvec USING ivfflat AS @@ -687,9 +663,7 @@ CREATE OPERATOR CLASS halfvec_cosine_ops FUNCTION 3 halfvec_spherical_distance(halfvec, halfvec), FUNCTION 4 l2_norm(halfvec), FUNCTION 5 l2_normalize(halfvec), - FUNCTION 6 ivfflat_halfvec_max_dims(internal), - FUNCTION 7 ivfflat_halfvec_update_center(internal, internal, internal), - FUNCTION 8 ivfflat_halfvec_sum_center(internal, internal); + FUNCTION 6 ivfflat_halfvec_support(internal); CREATE OPERATOR CLASS halfvec_l2_ops FOR TYPE halfvec USING hnsw AS diff --git a/src/ivfbuild.c b/src/ivfbuild.c index 0fad43e..d2b9c53 100644 --- a/src/ivfbuild.c +++ b/src/ivfbuild.c @@ -319,44 +319,13 @@ InsertTuples(Relation index, IvfflatBuildState * buildstate, ForkNumber forkNum) } } -/* - * Get max dimensions - */ -static int -GetMaxDimensions(Relation index) -{ - FmgrInfo *procinfo = IvfflatOptionalProcInfo(index, IVFFLAT_MAX_DIMS_PROC); - - if (procinfo == NULL) - return IVFFLAT_MAX_DIM; - - return DatumGetInt32(FunctionCall1(procinfo, PointerGetDatum(NULL))); -} - -/* - * Get item size - */ -static Size -GetItemSize(int maxDimensions, int dimensions) -{ - /* TODO Improve */ - if (maxDimensions == IVFFLAT_MAX_DIM) - return VECTOR_SIZE(dimensions); - else if (maxDimensions == IVFFLAT_MAX_DIM * 2) - return HALFVEC_SIZE(dimensions); - else if (maxDimensions == IVFFLAT_MAX_DIM * 32) - return VARBITTOTALLEN(dimensions); - else - elog(ERROR, "Unsupported type"); -} - /* * Initialize the build state */ static void InitBuildState(IvfflatBuildState * buildstate, Relation heap, Relation index, IndexInfo *indexInfo) { - int maxDimensions; + IvfflatTypeInfo *typeInfo = &buildstate->typeInfo; buildstate->heap = heap; buildstate->index = index; @@ -365,18 +334,19 @@ InitBuildState(IvfflatBuildState * buildstate, Relation heap, Relation index, In buildstate->lists = IvfflatGetLists(index); buildstate->dimensions = TupleDescAttr(index->rd_att, 0)->atttypmod; + typeInfo->dimensions = buildstate->dimensions; + GetTypeInfo(typeInfo, index); + /* Disallow varbit since require fixed dimensions */ if (TupleDescAttr(index->rd_att, 0)->atttypid == VARBITOID) elog(ERROR, "type not supported for ivfflat index"); - maxDimensions = GetMaxDimensions(index); - /* Require column to have dimensions to be indexed */ if (buildstate->dimensions < 0) elog(ERROR, "column does not have dimensions"); - if (buildstate->dimensions > maxDimensions) - elog(ERROR, "column cannot have more than %d dimensions for ivfflat index", maxDimensions); + if (buildstate->dimensions > typeInfo->maxDimensions) + elog(ERROR, "column cannot have more than %d dimensions for ivfflat index", typeInfo->maxDimensions); buildstate->reltuples = 0; buildstate->indtuples = 0; @@ -400,7 +370,7 @@ InitBuildState(IvfflatBuildState * buildstate, Relation heap, Relation index, In buildstate->slot = MakeSingleTupleTableSlot(buildstate->tupdesc, &TTSOpsVirtual); - buildstate->centers = VectorArrayInit(buildstate->lists, buildstate->dimensions, GetItemSize(maxDimensions, buildstate->dimensions)); + buildstate->centers = VectorArrayInit(buildstate->lists, buildstate->dimensions, typeInfo->itemsize); buildstate->listInfo = palloc(sizeof(ListInfo) * buildstate->lists); buildstate->tmpCtx = AllocSetContextCreate(CurrentMemoryContext, @@ -470,7 +440,7 @@ ComputeCenters(IvfflatBuildState * buildstate) } /* Calculate centers */ - IvfflatBench("k-means", IvfflatKmeans(buildstate->index, buildstate->samples, buildstate->centers)); + IvfflatBench("k-means", IvfflatKmeans(buildstate->index, buildstate->samples, buildstate->centers, &buildstate->typeInfo)); /* Free samples before we allocate more memory */ VectorArrayFree(buildstate->samples); diff --git a/src/ivfflat.c b/src/ivfflat.c index 9a226a6..6bb2422 100644 --- a/src/ivfflat.c +++ b/src/ivfflat.c @@ -188,7 +188,7 @@ ivfflathandler(PG_FUNCTION_ARGS) IndexAmRoutine *amroutine = makeNode(IndexAmRoutine); amroutine->amstrategies = 0; - amroutine->amsupport = 8; + amroutine->amsupport = 6; #if PG_VERSION_NUM >= 130000 amroutine->amoptsprocnum = 0; #endif diff --git a/src/ivfflat.h b/src/ivfflat.h index 348d96d..b0c5a58 100644 --- a/src/ivfflat.h +++ b/src/ivfflat.h @@ -29,9 +29,7 @@ #define IVFFLAT_KMEANS_DISTANCE_PROC 3 #define IVFFLAT_KMEANS_NORM_PROC 4 #define IVFFLAT_NORMALIZE_PROC 5 -#define IVFFLAT_MAX_DIMS_PROC 6 -#define IVFFLAT_UPDATE_CENTER_PROC 7 -#define IVFFLAT_SUM_CENTER_PROC 8 +#define IVFFLAT_TYPE_INFO_PROC 6 #define IVFFLAT_VERSION 1 #define IVFFLAT_MAGIC_NUMBER 0x14FF1A7 @@ -152,12 +150,22 @@ typedef struct IvfflatLeader char *ivfcenters; } IvfflatLeader; +typedef struct IvfflatTypeInfo +{ + int dimensions; + int maxDimensions; + int itemsize; + void (*updateCenter) (Pointer v, int dimensions, float *x); + void (*sumCenter) (Pointer v, float *x); +} IvfflatTypeInfo; + typedef struct IvfflatBuildState { /* Info */ Relation heap; Relation index; IndexInfo *indexInfo; + IvfflatTypeInfo typeInfo; /* Settings */ int dimensions; @@ -271,7 +279,7 @@ typedef IvfflatScanOpaqueData * IvfflatScanOpaque; /* Methods */ VectorArray VectorArrayInit(int maxlen, int dimensions, Size itemsize); void VectorArrayFree(VectorArray arr); -void IvfflatKmeans(Relation index, VectorArray samples, VectorArray centers); +void IvfflatKmeans(Relation index, VectorArray samples, VectorArray centers, IvfflatTypeInfo * typeInfo); FmgrInfo *IvfflatOptionalProcInfo(Relation index, uint16 procnum); Datum IvfflatNormValue(FmgrInfo *procinfo, Oid collation, Datum value); bool IvfflatCheckNorm(FmgrInfo *procinfo, Oid collation, Datum value); @@ -284,6 +292,7 @@ Buffer IvfflatNewBuffer(Relation index, ForkNumber forkNum); void IvfflatInitPage(Buffer buf, Page page); void IvfflatInitRegisterPage(Relation index, Buffer *buf, Page *page, GenericXLogState **state); void IvfflatInit(void); +void GetTypeInfo(IvfflatTypeInfo * typeInfo, Relation index); PGDLLEXPORT void IvfflatParallelBuildMain(dsm_segment *seg, shm_toc *toc); /* Index access methods */ diff --git a/src/ivfkmeans.c b/src/ivfkmeans.c index 6ac1ac8..84a2dd3 100644 --- a/src/ivfkmeans.c +++ b/src/ivfkmeans.c @@ -13,17 +13,6 @@ #include "utils/memutils.h" #include "vector.h" -/* Support functions */ -PGDLLEXPORT Datum ivfflat_vector_update_center(PG_FUNCTION_ARGS); -PGDLLEXPORT Datum ivfflat_vector_sum_center(PG_FUNCTION_ARGS); - -typedef struct KmeansState -{ - int dimensions; - FmgrInfo *updatecenterprocinfo; - FmgrInfo *sumcenterprocinfo; -} KmeansState; - /* * Initialize with kmeans++ * @@ -127,20 +116,11 @@ NormCenters(FmgrInfo *normalizeprocinfo, Oid collation, VectorArray centers) MemoryContextDelete(normCtx); } -static void -UpdateCenter(FmgrInfo *procinfo, Pointer center, int dimensions, float *x) -{ - if (procinfo == NULL) - DirectFunctionCall3(ivfflat_vector_update_center, PointerGetDatum(center), Int32GetDatum(dimensions), PointerGetDatum(x)); - else - FunctionCall3(procinfo, PointerGetDatum(center), Int32GetDatum(dimensions), PointerGetDatum(x)); -} - /* * Quick approach if we have no data */ static void -RandomCenters(Relation index, VectorArray centers, KmeansState * kmeansstate) +RandomCenters(Relation index, VectorArray centers, IvfflatTypeInfo * typeInfo) { int dimensions = centers->dim; Oid collation = index->rd_indcollation[0]; @@ -156,7 +136,7 @@ RandomCenters(Relation index, VectorArray centers, KmeansState * kmeansstate) for (int i = 0; i < dimensions; i++) x[i] = (float) RandomDouble(); - UpdateCenter(kmeansstate->updatecenterprocinfo, center, dimensions, x); + typeInfo->updateCenter(center, dimensions, x); centers->length++; } @@ -184,26 +164,17 @@ ShowMemoryUsage(MemoryContext context, Size estimatedSize) } #endif -static void -SumCenter(FmgrInfo *procinfo, Pointer sample, float *x) -{ - if (procinfo == NULL) - DirectFunctionCall2(ivfflat_vector_sum_center, PointerGetDatum(sample), PointerGetDatum(x)); - else - FunctionCall2(procinfo, PointerGetDatum(sample), PointerGetDatum(x)); -} - /* * Sum centers */ static void -SumCenters(VectorArray samples, float *agg, int *closestCenters, KmeansState * kmeansstate) +SumCenters(VectorArray samples, float *agg, int *closestCenters, IvfflatTypeInfo * typeInfo) { for (int j = 0; j < samples->length; j++) { - float *x = agg + ((int64) closestCenters[j] * kmeansstate->dimensions); + float *x = agg + ((int64) closestCenters[j] * samples->dim); - SumCenter(kmeansstate->sumcenterprocinfo, VectorArrayGet(samples, j), x); + typeInfo->sumCenter(VectorArrayGet(samples, j), x); } } @@ -211,13 +182,13 @@ SumCenters(VectorArray samples, float *agg, int *closestCenters, KmeansState * k * Update centers */ static void -UpdateCenters(float *agg, VectorArray centers, KmeansState * kmeansstate) +UpdateCenters(float *agg, VectorArray centers, IvfflatTypeInfo * typeInfo) { for (int j = 0; j < centers->length; j++) { - float *x = agg + ((int64) j * kmeansstate->dimensions); + float *x = agg + ((int64) j * centers->dim); - UpdateCenter(kmeansstate->updatecenterprocinfo, VectorArrayGet(centers, j), centers->dim, x); + typeInfo->updateCenter(VectorArrayGet(centers, j), centers->dim, x); } } @@ -225,9 +196,9 @@ UpdateCenters(float *agg, VectorArray centers, KmeansState * kmeansstate) * Compute new centers */ static void -ComputeNewCenters(VectorArray samples, float *agg, VectorArray newCenters, int *centerCounts, int *closestCenters, FmgrInfo *normprocinfo, FmgrInfo *normalizeprocinfo, Oid collation, KmeansState * kmeansstate) +ComputeNewCenters(VectorArray samples, float *agg, VectorArray newCenters, int *centerCounts, int *closestCenters, FmgrInfo *normprocinfo, FmgrInfo *normalizeprocinfo, Oid collation, IvfflatTypeInfo * typeInfo) { - int dimensions = kmeansstate->dimensions; + int dimensions = newCenters->dim; int numCenters = newCenters->length; int numSamples = samples->length; @@ -243,7 +214,7 @@ ComputeNewCenters(VectorArray samples, float *agg, VectorArray newCenters, int * } /* Increment sum of closest center */ - SumCenters(samples, agg, closestCenters, kmeansstate); + SumCenters(samples, agg, closestCenters, typeInfo); /* Increment count of closest center */ for (int j = 0; j < numSamples; j++) @@ -276,7 +247,7 @@ ComputeNewCenters(VectorArray samples, float *agg, VectorArray newCenters, int * } /* Set new centers */ - UpdateCenters(agg, newCenters, kmeansstate); + UpdateCenters(agg, newCenters, typeInfo); /* Normalize if needed */ if (normprocinfo != NULL) @@ -292,7 +263,7 @@ ComputeNewCenters(VectorArray samples, float *agg, VectorArray newCenters, int * * https://www.aaai.org/Papers/ICML/2003/ICML03-022.pdf */ static void -ElkanKmeans(Relation index, VectorArray samples, VectorArray centers, KmeansState * kmeansstate) +ElkanKmeans(Relation index, VectorArray samples, VectorArray centers, IvfflatTypeInfo * typeInfo) { FmgrInfo *procinfo; FmgrInfo *normprocinfo; @@ -506,7 +477,7 @@ ElkanKmeans(Relation index, VectorArray samples, VectorArray centers, KmeansStat } /* Step 4: For each center c, let m(c) be mean of all points assigned */ - ComputeNewCenters(samples, agg, newCenters, centerCounts, closestCenters, normprocinfo, normalizeprocinfo, collation, kmeansstate); + ComputeNewCenters(samples, agg, newCenters, centerCounts, closestCenters, normprocinfo, normalizeprocinfo, collation, typeInfo); /* Step 5 */ for (int j = 0; j < numCenters; j++) @@ -546,7 +517,7 @@ ElkanKmeans(Relation index, VectorArray samples, VectorArray centers, KmeansStat * Detect issues with centers */ static void -CheckCenters(Relation index, VectorArray centers, KmeansState * kmeansstate) +CheckCenters(Relation index, VectorArray centers, IvfflatTypeInfo * typeInfo) { FmgrInfo *normprocinfo; float *scratch = palloc(sizeof(float) * centers->dim); @@ -560,7 +531,7 @@ CheckCenters(Relation index, VectorArray centers, KmeansState * kmeansstate) for (int j = 0; j < centers->dim; j++) scratch[j] = 0; - SumCenter(kmeansstate->sumcenterprocinfo, VectorArrayGet(centers, i), scratch); + typeInfo->sumCenter(VectorArrayGet(centers, i), scratch); for (int j = 0; j < centers->dim; j++) { @@ -591,29 +562,17 @@ CheckCenters(Relation index, VectorArray centers, KmeansState * kmeansstate) pfree(scratch); } -static void -InitKmeansState(KmeansState * kmeansstate, Relation index, int dimensions) -{ - kmeansstate->dimensions = dimensions; - kmeansstate->updatecenterprocinfo = IvfflatOptionalProcInfo(index, IVFFLAT_UPDATE_CENTER_PROC); - kmeansstate->sumcenterprocinfo = IvfflatOptionalProcInfo(index, IVFFLAT_SUM_CENTER_PROC); -} - /* * Perform naive k-means centering * We use spherical k-means for inner product and cosine */ void -IvfflatKmeans(Relation index, VectorArray samples, VectorArray centers) +IvfflatKmeans(Relation index, VectorArray samples, VectorArray centers, IvfflatTypeInfo * typeInfo) { - KmeansState kmeansstate; - - InitKmeansState(&kmeansstate, index, centers->dim); - if (samples->length == 0) - RandomCenters(index, centers, &kmeansstate); + RandomCenters(index, centers, typeInfo); else - ElkanKmeans(index, samples, centers, &kmeansstate); + ElkanKmeans(index, samples, centers, typeInfo); - CheckCenters(index, centers, &kmeansstate); + CheckCenters(index, centers, typeInfo); } diff --git a/src/ivfutils.c b/src/ivfutils.c index 95e2697..2bc8158 100644 --- a/src/ivfutils.c +++ b/src/ivfutils.c @@ -228,61 +228,34 @@ IvfflatUpdateList(Relation index, ListInfo listInfo, } } -PGDLLEXPORT PG_FUNCTION_INFO_V1(ivfflat_halfvec_max_dims); -Datum -ivfflat_halfvec_max_dims(PG_FUNCTION_ARGS) +static void +VectorUpdateCenter(Pointer v, int dimensions, float *x) { - PG_RETURN_INT32(IVFFLAT_MAX_DIM * 2); -}; - -PGDLLEXPORT PG_FUNCTION_INFO_V1(ivfflat_bit_max_dims); -Datum -ivfflat_bit_max_dims(PG_FUNCTION_ARGS) -{ - PG_RETURN_INT32(IVFFLAT_MAX_DIM * 32); -}; - -PGDLLEXPORT PG_FUNCTION_INFO_V1(ivfflat_vector_update_center); -Datum -ivfflat_vector_update_center(PG_FUNCTION_ARGS) -{ - Vector *vec = PG_GETARG_VECTOR_P(0); - int dimensions = PG_GETARG_INT32(1); - float *x = (float *) PG_GETARG_POINTER(2); + Vector *vec = (Vector *) v; SET_VARSIZE(vec, VECTOR_SIZE(dimensions)); vec->dim = dimensions; for (int k = 0; k < dimensions; k++) vec->x[k] = x[k]; +} - PG_RETURN_VOID(); -}; - -PGDLLEXPORT PG_FUNCTION_INFO_V1(ivfflat_halfvec_update_center); -Datum -ivfflat_halfvec_update_center(PG_FUNCTION_ARGS) +static void +HalfvecUpdateCenter(Pointer v, int dimensions, float *x) { - HalfVector *vec = PG_GETARG_HALFVEC_P(0); - int dimensions = PG_GETARG_INT32(1); - float *x = (float *) PG_GETARG_POINTER(2); + HalfVector *vec = (HalfVector *) v; SET_VARSIZE(vec, HALFVEC_SIZE(dimensions)); vec->dim = dimensions; for (int k = 0; k < dimensions; k++) vec->x[k] = Float4ToHalfUnchecked(x[k]); +} - PG_RETURN_VOID(); -}; - -PGDLLEXPORT PG_FUNCTION_INFO_V1(ivfflat_bit_update_center); -Datum -ivfflat_bit_update_center(PG_FUNCTION_ARGS) +static void +BitUpdateCenter(Pointer v, int dimensions, float *x) { - VarBit *vec = PG_GETARG_VARBIT_P(0); - int dimensions = PG_GETARG_INT32(1); - float *x = (float *) PG_GETARG_POINTER(2); + VarBit *vec = (VarBit *) v; unsigned char *nx = VARBITS(vec); SET_VARSIZE(vec, VARBITTOTALLEN(dimensions)); @@ -293,45 +266,78 @@ ivfflat_bit_update_center(PG_FUNCTION_ARGS) for (int k = 0; k < dimensions; k++) nx[k / 8] |= (x[k] > 0.5 ? 1 : 0) << (7 - (k % 8)); +} - PG_RETURN_VOID(); -}; - -PGDLLEXPORT PG_FUNCTION_INFO_V1(ivfflat_vector_sum_center); -Datum -ivfflat_vector_sum_center(PG_FUNCTION_ARGS) +static void +VectorSumCenter(Pointer v, float *x) { - Vector *vec = PG_GETARG_VECTOR_P(0); - float *x = (float *) PG_GETARG_POINTER(1); + Vector *vec = (Vector *) v; for (int k = 0; k < vec->dim; k++) x[k] += vec->x[k]; +} + +static void +HalfvecSumCenter(Pointer v, float *x) +{ + HalfVector *vec = (HalfVector *) v; + + for (int k = 0; k < vec->dim; k++) + x[k] += HalfToFloat4(vec->x[k]); +} + +static void +BitSumCenter(Pointer v, float *x) +{ + VarBit *vec = (VarBit *) v; + + for (int k = 0; k < VARBITLEN(vec); k++) + x[k] += (float) (((VARBITS(vec)[k / 8]) >> (7 - (k % 8))) & 0x01); +} + +/* + * Get type info + */ +void +GetTypeInfo(IvfflatTypeInfo * typeInfo, Relation index) +{ + FmgrInfo *procinfo = IvfflatOptionalProcInfo(index, IVFFLAT_TYPE_INFO_PROC); + + if (procinfo == NULL) + { + typeInfo->maxDimensions = IVFFLAT_MAX_DIM; + typeInfo->itemsize = VECTOR_SIZE(typeInfo->dimensions); + typeInfo->updateCenter = VectorUpdateCenter; + typeInfo->sumCenter = VectorSumCenter; + } + else + FunctionCall1(procinfo, PointerGetDatum(typeInfo)); +} + +PGDLLEXPORT PG_FUNCTION_INFO_V1(ivfflat_halfvec_support); +Datum +ivfflat_halfvec_support(PG_FUNCTION_ARGS) +{ + IvfflatTypeInfo *typeInfo = (IvfflatTypeInfo *) PG_GETARG_POINTER(0); + + typeInfo->maxDimensions = IVFFLAT_MAX_DIM * 2; + typeInfo->itemsize = HALFVEC_SIZE(typeInfo->dimensions); + typeInfo->updateCenter = HalfvecUpdateCenter; + typeInfo->sumCenter = HalfvecSumCenter; PG_RETURN_VOID(); }; -PGDLLEXPORT PG_FUNCTION_INFO_V1(ivfflat_halfvec_sum_center); +PGDLLEXPORT PG_FUNCTION_INFO_V1(ivfflat_bit_support); Datum -ivfflat_halfvec_sum_center(PG_FUNCTION_ARGS) +ivfflat_bit_support(PG_FUNCTION_ARGS) { - HalfVector *vec = PG_GETARG_HALFVEC_P(0); - float *x = (float *) PG_GETARG_POINTER(1); + IvfflatTypeInfo *typeInfo = (IvfflatTypeInfo *) PG_GETARG_POINTER(0); - for (int k = 0; k < vec->dim; k++) - x[k] += HalfToFloat4(vec->x[k]); + typeInfo->maxDimensions = IVFFLAT_MAX_DIM * 32; + typeInfo->itemsize = VARBITTOTALLEN(typeInfo->dimensions); + typeInfo->updateCenter = BitUpdateCenter; + typeInfo->sumCenter = BitSumCenter; PG_RETURN_VOID(); -} - -PGDLLEXPORT PG_FUNCTION_INFO_V1(ivfflat_bit_sum_center); -Datum -ivfflat_bit_sum_center(PG_FUNCTION_ARGS) -{ - VarBit *vec = PG_GETARG_VARBIT_P(0); - float *x = (float *) PG_GETARG_POINTER(1); - - for (int k = 0; k < VARBITLEN(vec); k++) - x[k] += (float) (((VARBITS(vec)[k / 8]) >> (7 - (k % 8))) & 0x01); - - PG_RETURN_VOID(); -} +};