From ec640f3b579927e77cdaf5758e8429529a514dcc Mon Sep 17 00:00:00 2001 From: Andrew Kane Date: Thu, 25 Apr 2024 12:30:49 -0700 Subject: [PATCH] Switched to static const for IVFFlat type info --- src/ivfbuild.c | 15 ++++++--------- src/ivfflat.h | 6 +++--- src/ivfkmeans.c | 14 +++++++------- src/ivfutils.c | 43 ++++++++++++++++++++++--------------------- 4 files changed, 38 insertions(+), 40 deletions(-) diff --git a/src/ivfbuild.c b/src/ivfbuild.c index b0b2d7f..7b87962 100644 --- a/src/ivfbuild.c +++ b/src/ivfbuild.c @@ -325,18 +325,14 @@ InsertTuples(Relation index, IvfflatBuildState * buildstate, ForkNumber forkNum) static void InitBuildState(IvfflatBuildState * buildstate, Relation heap, Relation index, IndexInfo *indexInfo) { - IvfflatTypeInfo *typeInfo = &buildstate->typeInfo; - buildstate->heap = heap; buildstate->index = index; buildstate->indexInfo = indexInfo; + buildstate->typeInfo = IvfflatGetTypeInfo(index); buildstate->lists = IvfflatGetLists(index); buildstate->dimensions = TupleDescAttr(index->rd_att, 0)->atttypmod; - typeInfo->dimensions = buildstate->dimensions; - IvfflatGetTypeInfo(typeInfo, index); - /* Disallow varbit since require fixed dimensions */ if (TupleDescAttr(index->rd_att, 0)->atttypid == VARBITOID) elog(ERROR, "type not supported for ivfflat index"); @@ -345,8 +341,8 @@ InitBuildState(IvfflatBuildState * buildstate, Relation heap, Relation index, In if (buildstate->dimensions < 0) elog(ERROR, "column does not have dimensions"); - if (buildstate->dimensions > typeInfo->maxDimensions) - elog(ERROR, "column cannot have more than %d dimensions for ivfflat index", typeInfo->maxDimensions); + if (buildstate->dimensions > buildstate->typeInfo->maxDimensions) + elog(ERROR, "column cannot have more than %d dimensions for ivfflat index", buildstate->typeInfo->maxDimensions); buildstate->reltuples = 0; buildstate->indtuples = 0; @@ -370,7 +366,8 @@ InitBuildState(IvfflatBuildState * buildstate, Relation heap, Relation index, In buildstate->slot = MakeSingleTupleTableSlot(buildstate->tupdesc, &TTSOpsVirtual); - buildstate->centers = VectorArrayInit(buildstate->lists, buildstate->dimensions, typeInfo->itemsize); + /* TODO Fix item size */ + buildstate->centers = VectorArrayInit(buildstate->lists, buildstate->dimensions, VECTOR_SIZE(buildstate->dimensions)); buildstate->listInfo = palloc(sizeof(ListInfo) * buildstate->lists); buildstate->tmpCtx = AllocSetContextCreate(CurrentMemoryContext, @@ -440,7 +437,7 @@ ComputeCenters(IvfflatBuildState * buildstate) } /* Calculate centers */ - IvfflatBench("k-means", IvfflatKmeans(buildstate->index, buildstate->samples, buildstate->centers, &buildstate->typeInfo)); + IvfflatBench("k-means", IvfflatKmeans(buildstate->index, buildstate->samples, buildstate->centers, buildstate->typeInfo)); /* Free samples before we allocate more memory */ VectorArrayFree(buildstate->samples); diff --git a/src/ivfflat.h b/src/ivfflat.h index 32b1afd..4b840b6 100644 --- a/src/ivfflat.h +++ b/src/ivfflat.h @@ -165,7 +165,7 @@ typedef struct IvfflatBuildState Relation heap; Relation index; IndexInfo *indexInfo; - IvfflatTypeInfo typeInfo; + const IvfflatTypeInfo *typeInfo; /* Settings */ int dimensions; @@ -279,7 +279,7 @@ typedef IvfflatScanOpaqueData * IvfflatScanOpaque; /* Methods */ VectorArray VectorArrayInit(int maxlen, int dimensions, Size itemsize); void VectorArrayFree(VectorArray arr); -void IvfflatKmeans(Relation index, VectorArray samples, VectorArray centers, IvfflatTypeInfo * typeInfo); +void IvfflatKmeans(Relation index, VectorArray samples, VectorArray centers, const IvfflatTypeInfo * typeInfo); FmgrInfo *IvfflatOptionalProcInfo(Relation index, uint16 procnum); Datum IvfflatNormValue(FmgrInfo *procinfo, Oid collation, Datum value); bool IvfflatCheckNorm(FmgrInfo *procinfo, Oid collation, Datum value); @@ -292,7 +292,7 @@ Buffer IvfflatNewBuffer(Relation index, ForkNumber forkNum); void IvfflatInitPage(Buffer buf, Page page); void IvfflatInitRegisterPage(Relation index, Buffer *buf, Page *page, GenericXLogState **state); void IvfflatInit(void); -void IvfflatGetTypeInfo(IvfflatTypeInfo * typeInfo, Relation index); +const IvfflatTypeInfo *IvfflatGetTypeInfo(Relation index); PGDLLEXPORT void IvfflatParallelBuildMain(dsm_segment *seg, shm_toc *toc); /* Index access methods */ diff --git a/src/ivfkmeans.c b/src/ivfkmeans.c index b9170a0..89b5cf7 100644 --- a/src/ivfkmeans.c +++ b/src/ivfkmeans.c @@ -120,7 +120,7 @@ NormCenters(FmgrInfo *normalizeprocinfo, Oid collation, VectorArray centers) * Quick approach if we have no data */ static void -RandomCenters(Relation index, VectorArray centers, IvfflatTypeInfo * typeInfo) +RandomCenters(Relation index, VectorArray centers, const IvfflatTypeInfo * typeInfo) { int dimensions = centers->dim; Oid collation = index->rd_indcollation[0]; @@ -168,7 +168,7 @@ ShowMemoryUsage(MemoryContext context, Size estimatedSize) * Sum centers */ static void -SumCenters(VectorArray samples, float *agg, int *closestCenters, IvfflatTypeInfo * typeInfo) +SumCenters(VectorArray samples, float *agg, int *closestCenters, const IvfflatTypeInfo * typeInfo) { for (int j = 0; j < samples->length; j++) { @@ -182,7 +182,7 @@ SumCenters(VectorArray samples, float *agg, int *closestCenters, IvfflatTypeInfo * Update centers */ static void -UpdateCenters(float *agg, VectorArray centers, IvfflatTypeInfo * typeInfo) +UpdateCenters(float *agg, VectorArray centers, const IvfflatTypeInfo * typeInfo) { for (int j = 0; j < centers->length; j++) { @@ -196,7 +196,7 @@ UpdateCenters(float *agg, VectorArray centers, IvfflatTypeInfo * typeInfo) * Compute new centers */ static void -ComputeNewCenters(VectorArray samples, float *agg, VectorArray newCenters, int *centerCounts, int *closestCenters, FmgrInfo *normprocinfo, FmgrInfo *normalizeprocinfo, Oid collation, IvfflatTypeInfo * typeInfo) +ComputeNewCenters(VectorArray samples, float *agg, VectorArray newCenters, int *centerCounts, int *closestCenters, FmgrInfo *normprocinfo, FmgrInfo *normalizeprocinfo, Oid collation, const IvfflatTypeInfo * typeInfo) { int dimensions = newCenters->dim; int numCenters = newCenters->length; @@ -263,7 +263,7 @@ ComputeNewCenters(VectorArray samples, float *agg, VectorArray newCenters, int * * https://www.aaai.org/Papers/ICML/2003/ICML03-022.pdf */ static void -ElkanKmeans(Relation index, VectorArray samples, VectorArray centers, IvfflatTypeInfo * typeInfo) +ElkanKmeans(Relation index, VectorArray samples, VectorArray centers, const IvfflatTypeInfo * typeInfo) { FmgrInfo *procinfo; FmgrInfo *normprocinfo; @@ -517,7 +517,7 @@ ElkanKmeans(Relation index, VectorArray samples, VectorArray centers, IvfflatTyp * Detect issues with centers */ static void -CheckCenters(Relation index, VectorArray centers, IvfflatTypeInfo * typeInfo) +CheckCenters(Relation index, VectorArray centers, const IvfflatTypeInfo * typeInfo) { FmgrInfo *normprocinfo; float *scratch = palloc(sizeof(float) * centers->dim); @@ -568,7 +568,7 @@ CheckCenters(Relation index, VectorArray centers, IvfflatTypeInfo * typeInfo) * We use spherical k-means for inner product and cosine */ void -IvfflatKmeans(Relation index, VectorArray samples, VectorArray centers, IvfflatTypeInfo * typeInfo) +IvfflatKmeans(Relation index, VectorArray samples, VectorArray centers, const IvfflatTypeInfo * typeInfo) { if (samples->length == 0) RandomCenters(index, centers, typeInfo); diff --git a/src/ivfutils.c b/src/ivfutils.c index 67bd4c5..7b5f152 100644 --- a/src/ivfutils.c +++ b/src/ivfutils.c @@ -298,46 +298,47 @@ BitSumCenter(Pointer v, float *x) /* * Get type info */ -void -IvfflatGetTypeInfo(IvfflatTypeInfo * typeInfo, Relation index) +const IvfflatTypeInfo * +IvfflatGetTypeInfo(Relation index) { FmgrInfo *procinfo = IvfflatOptionalProcInfo(index, IVFFLAT_TYPE_INFO_PROC); if (procinfo == NULL) { - typeInfo->maxDimensions = IVFFLAT_MAX_DIM; - typeInfo->itemsize = VECTOR_SIZE(typeInfo->dimensions); - typeInfo->updateCenter = VectorUpdateCenter; - typeInfo->sumCenter = VectorSumCenter; + static const IvfflatTypeInfo typeInfo = { + .maxDimensions = IVFFLAT_MAX_DIM, + .updateCenter = VectorUpdateCenter, + .sumCenter = VectorSumCenter + }; + + return (&typeInfo); } else - FunctionCall1(procinfo, PointerGetDatum(typeInfo)); + return (const IvfflatTypeInfo *) DatumGetPointer(FunctionCall0Coll(procinfo, InvalidOid)); } PGDLLEXPORT PG_FUNCTION_INFO_V1(ivfflat_halfvec_support); Datum ivfflat_halfvec_support(PG_FUNCTION_ARGS) { - IvfflatTypeInfo *typeInfo = (IvfflatTypeInfo *) PG_GETARG_POINTER(0); + static const IvfflatTypeInfo typeInfo = { + .maxDimensions = IVFFLAT_MAX_DIM * 2, + .updateCenter = HalfvecUpdateCenter, + .sumCenter = HalfvecSumCenter + }; - typeInfo->maxDimensions = IVFFLAT_MAX_DIM * 2; - typeInfo->itemsize = HALFVEC_SIZE(typeInfo->dimensions); - typeInfo->updateCenter = HalfvecUpdateCenter; - typeInfo->sumCenter = HalfvecSumCenter; - - PG_RETURN_VOID(); + PG_RETURN_POINTER(&typeInfo); }; PGDLLEXPORT PG_FUNCTION_INFO_V1(ivfflat_bit_support); Datum ivfflat_bit_support(PG_FUNCTION_ARGS) { - IvfflatTypeInfo *typeInfo = (IvfflatTypeInfo *) PG_GETARG_POINTER(0); + static const IvfflatTypeInfo typeInfo = { + .maxDimensions = IVFFLAT_MAX_DIM * 32, + .updateCenter = BitUpdateCenter, + .sumCenter = BitSumCenter + }; - typeInfo->maxDimensions = IVFFLAT_MAX_DIM * 32; - typeInfo->itemsize = VARBITTOTALLEN(typeInfo->dimensions); - typeInfo->updateCenter = BitUpdateCenter; - typeInfo->sumCenter = BitSumCenter; - - PG_RETURN_VOID(); + PG_RETURN_POINTER(&typeInfo); };