Switched to static const for IVFFlat type info

This commit is contained in:
Andrew Kane
2024-04-25 12:30:49 -07:00
parent 91cf4d223e
commit ec640f3b57
4 changed files with 38 additions and 40 deletions

View File

@@ -325,18 +325,14 @@ InsertTuples(Relation index, IvfflatBuildState * buildstate, ForkNumber forkNum)
static void
InitBuildState(IvfflatBuildState * buildstate, Relation heap, Relation index, IndexInfo *indexInfo)
{
IvfflatTypeInfo *typeInfo = &buildstate->typeInfo;
buildstate->heap = heap;
buildstate->index = index;
buildstate->indexInfo = indexInfo;
buildstate->typeInfo = IvfflatGetTypeInfo(index);
buildstate->lists = IvfflatGetLists(index);
buildstate->dimensions = TupleDescAttr(index->rd_att, 0)->atttypmod;
typeInfo->dimensions = buildstate->dimensions;
IvfflatGetTypeInfo(typeInfo, index);
/* Disallow varbit since require fixed dimensions */
if (TupleDescAttr(index->rd_att, 0)->atttypid == VARBITOID)
elog(ERROR, "type not supported for ivfflat index");
@@ -345,8 +341,8 @@ InitBuildState(IvfflatBuildState * buildstate, Relation heap, Relation index, In
if (buildstate->dimensions < 0)
elog(ERROR, "column does not have dimensions");
if (buildstate->dimensions > typeInfo->maxDimensions)
elog(ERROR, "column cannot have more than %d dimensions for ivfflat index", typeInfo->maxDimensions);
if (buildstate->dimensions > buildstate->typeInfo->maxDimensions)
elog(ERROR, "column cannot have more than %d dimensions for ivfflat index", buildstate->typeInfo->maxDimensions);
buildstate->reltuples = 0;
buildstate->indtuples = 0;
@@ -370,7 +366,8 @@ InitBuildState(IvfflatBuildState * buildstate, Relation heap, Relation index, In
buildstate->slot = MakeSingleTupleTableSlot(buildstate->tupdesc, &TTSOpsVirtual);
buildstate->centers = VectorArrayInit(buildstate->lists, buildstate->dimensions, typeInfo->itemsize);
/* TODO Fix item size */
buildstate->centers = VectorArrayInit(buildstate->lists, buildstate->dimensions, VECTOR_SIZE(buildstate->dimensions));
buildstate->listInfo = palloc(sizeof(ListInfo) * buildstate->lists);
buildstate->tmpCtx = AllocSetContextCreate(CurrentMemoryContext,
@@ -440,7 +437,7 @@ ComputeCenters(IvfflatBuildState * buildstate)
}
/* Calculate centers */
IvfflatBench("k-means", IvfflatKmeans(buildstate->index, buildstate->samples, buildstate->centers, &buildstate->typeInfo));
IvfflatBench("k-means", IvfflatKmeans(buildstate->index, buildstate->samples, buildstate->centers, buildstate->typeInfo));
/* Free samples before we allocate more memory */
VectorArrayFree(buildstate->samples);

View File

@@ -165,7 +165,7 @@ typedef struct IvfflatBuildState
Relation heap;
Relation index;
IndexInfo *indexInfo;
IvfflatTypeInfo typeInfo;
const IvfflatTypeInfo *typeInfo;
/* Settings */
int dimensions;
@@ -279,7 +279,7 @@ typedef IvfflatScanOpaqueData * IvfflatScanOpaque;
/* Methods */
VectorArray VectorArrayInit(int maxlen, int dimensions, Size itemsize);
void VectorArrayFree(VectorArray arr);
void IvfflatKmeans(Relation index, VectorArray samples, VectorArray centers, IvfflatTypeInfo * typeInfo);
void IvfflatKmeans(Relation index, VectorArray samples, VectorArray centers, const IvfflatTypeInfo * typeInfo);
FmgrInfo *IvfflatOptionalProcInfo(Relation index, uint16 procnum);
Datum IvfflatNormValue(FmgrInfo *procinfo, Oid collation, Datum value);
bool IvfflatCheckNorm(FmgrInfo *procinfo, Oid collation, Datum value);
@@ -292,7 +292,7 @@ Buffer IvfflatNewBuffer(Relation index, ForkNumber forkNum);
void IvfflatInitPage(Buffer buf, Page page);
void IvfflatInitRegisterPage(Relation index, Buffer *buf, Page *page, GenericXLogState **state);
void IvfflatInit(void);
void IvfflatGetTypeInfo(IvfflatTypeInfo * typeInfo, Relation index);
const IvfflatTypeInfo *IvfflatGetTypeInfo(Relation index);
PGDLLEXPORT void IvfflatParallelBuildMain(dsm_segment *seg, shm_toc *toc);
/* Index access methods */

View File

@@ -120,7 +120,7 @@ NormCenters(FmgrInfo *normalizeprocinfo, Oid collation, VectorArray centers)
* Quick approach if we have no data
*/
static void
RandomCenters(Relation index, VectorArray centers, IvfflatTypeInfo * typeInfo)
RandomCenters(Relation index, VectorArray centers, const IvfflatTypeInfo * typeInfo)
{
int dimensions = centers->dim;
Oid collation = index->rd_indcollation[0];
@@ -168,7 +168,7 @@ ShowMemoryUsage(MemoryContext context, Size estimatedSize)
* Sum centers
*/
static void
SumCenters(VectorArray samples, float *agg, int *closestCenters, IvfflatTypeInfo * typeInfo)
SumCenters(VectorArray samples, float *agg, int *closestCenters, const IvfflatTypeInfo * typeInfo)
{
for (int j = 0; j < samples->length; j++)
{
@@ -182,7 +182,7 @@ SumCenters(VectorArray samples, float *agg, int *closestCenters, IvfflatTypeInfo
* Update centers
*/
static void
UpdateCenters(float *agg, VectorArray centers, IvfflatTypeInfo * typeInfo)
UpdateCenters(float *agg, VectorArray centers, const IvfflatTypeInfo * typeInfo)
{
for (int j = 0; j < centers->length; j++)
{
@@ -196,7 +196,7 @@ UpdateCenters(float *agg, VectorArray centers, IvfflatTypeInfo * typeInfo)
* Compute new centers
*/
static void
ComputeNewCenters(VectorArray samples, float *agg, VectorArray newCenters, int *centerCounts, int *closestCenters, FmgrInfo *normprocinfo, FmgrInfo *normalizeprocinfo, Oid collation, IvfflatTypeInfo * typeInfo)
ComputeNewCenters(VectorArray samples, float *agg, VectorArray newCenters, int *centerCounts, int *closestCenters, FmgrInfo *normprocinfo, FmgrInfo *normalizeprocinfo, Oid collation, const IvfflatTypeInfo * typeInfo)
{
int dimensions = newCenters->dim;
int numCenters = newCenters->length;
@@ -263,7 +263,7 @@ ComputeNewCenters(VectorArray samples, float *agg, VectorArray newCenters, int *
* https://www.aaai.org/Papers/ICML/2003/ICML03-022.pdf
*/
static void
ElkanKmeans(Relation index, VectorArray samples, VectorArray centers, IvfflatTypeInfo * typeInfo)
ElkanKmeans(Relation index, VectorArray samples, VectorArray centers, const IvfflatTypeInfo * typeInfo)
{
FmgrInfo *procinfo;
FmgrInfo *normprocinfo;
@@ -517,7 +517,7 @@ ElkanKmeans(Relation index, VectorArray samples, VectorArray centers, IvfflatTyp
* Detect issues with centers
*/
static void
CheckCenters(Relation index, VectorArray centers, IvfflatTypeInfo * typeInfo)
CheckCenters(Relation index, VectorArray centers, const IvfflatTypeInfo * typeInfo)
{
FmgrInfo *normprocinfo;
float *scratch = palloc(sizeof(float) * centers->dim);
@@ -568,7 +568,7 @@ CheckCenters(Relation index, VectorArray centers, IvfflatTypeInfo * typeInfo)
* We use spherical k-means for inner product and cosine
*/
void
IvfflatKmeans(Relation index, VectorArray samples, VectorArray centers, IvfflatTypeInfo * typeInfo)
IvfflatKmeans(Relation index, VectorArray samples, VectorArray centers, const IvfflatTypeInfo * typeInfo)
{
if (samples->length == 0)
RandomCenters(index, centers, typeInfo);

View File

@@ -298,46 +298,47 @@ BitSumCenter(Pointer v, float *x)
/*
* Get type info
*/
void
IvfflatGetTypeInfo(IvfflatTypeInfo * typeInfo, Relation index)
const IvfflatTypeInfo *
IvfflatGetTypeInfo(Relation index)
{
FmgrInfo *procinfo = IvfflatOptionalProcInfo(index, IVFFLAT_TYPE_INFO_PROC);
if (procinfo == NULL)
{
typeInfo->maxDimensions = IVFFLAT_MAX_DIM;
typeInfo->itemsize = VECTOR_SIZE(typeInfo->dimensions);
typeInfo->updateCenter = VectorUpdateCenter;
typeInfo->sumCenter = VectorSumCenter;
static const IvfflatTypeInfo typeInfo = {
.maxDimensions = IVFFLAT_MAX_DIM,
.updateCenter = VectorUpdateCenter,
.sumCenter = VectorSumCenter
};
return (&typeInfo);
}
else
FunctionCall1(procinfo, PointerGetDatum(typeInfo));
return (const IvfflatTypeInfo *) DatumGetPointer(FunctionCall0Coll(procinfo, InvalidOid));
}
PGDLLEXPORT PG_FUNCTION_INFO_V1(ivfflat_halfvec_support);
Datum
ivfflat_halfvec_support(PG_FUNCTION_ARGS)
{
IvfflatTypeInfo *typeInfo = (IvfflatTypeInfo *) PG_GETARG_POINTER(0);
static const IvfflatTypeInfo typeInfo = {
.maxDimensions = IVFFLAT_MAX_DIM * 2,
.updateCenter = HalfvecUpdateCenter,
.sumCenter = HalfvecSumCenter
};
typeInfo->maxDimensions = IVFFLAT_MAX_DIM * 2;
typeInfo->itemsize = HALFVEC_SIZE(typeInfo->dimensions);
typeInfo->updateCenter = HalfvecUpdateCenter;
typeInfo->sumCenter = HalfvecSumCenter;
PG_RETURN_VOID();
PG_RETURN_POINTER(&typeInfo);
};
PGDLLEXPORT PG_FUNCTION_INFO_V1(ivfflat_bit_support);
Datum
ivfflat_bit_support(PG_FUNCTION_ARGS)
{
IvfflatTypeInfo *typeInfo = (IvfflatTypeInfo *) PG_GETARG_POINTER(0);
static const IvfflatTypeInfo typeInfo = {
.maxDimensions = IVFFLAT_MAX_DIM * 32,
.updateCenter = BitUpdateCenter,
.sumCenter = BitSumCenter
};
typeInfo->maxDimensions = IVFFLAT_MAX_DIM * 32;
typeInfo->itemsize = VARBITTOTALLEN(typeInfo->dimensions);
typeInfo->updateCenter = BitUpdateCenter;
typeInfo->sumCenter = BitSumCenter;
PG_RETURN_VOID();
PG_RETURN_POINTER(&typeInfo);
};