diff --git a/src/ivfbuild.c b/src/ivfbuild.c index 6f582fd..d2427a4 100644 --- a/src/ivfbuild.c +++ b/src/ivfbuild.c @@ -388,9 +388,11 @@ InitBuildState(IvfflatBuildState * buildstate, Relation heap, Relation index, In buildstate->slot = MakeSingleTupleTableSlot(buildstate->sortdesc, &TTSOpsVirtual); + buildstate->memoryUsed = 0; buildstate->itemsize = buildstate->typeInfo->itemSize(buildstate->dimensions); /* TODO Ensure within maintenance_work_mem */ + buildstate->memoryUsed += VECTOR_ARRAY_SIZE(buildstate->lists, buildstate->itemsize); buildstate->centers = VectorArrayInit(buildstate->lists, buildstate->dimensions, buildstate->itemsize); /* TODO Move allocation to page creation */ @@ -448,6 +450,7 @@ ComputeCenters(IvfflatBuildState * buildstate) /* Sample rows */ /* TODO Ensure within maintenance_work_mem */ + buildstate->memoryUsed += VECTOR_ARRAY_SIZE(numSamples, buildstate->itemsize); buildstate->samples = VectorArrayInit(numSamples, buildstate->dimensions, buildstate->itemsize); if (buildstate->heap != NULL) { @@ -463,7 +466,7 @@ ComputeCenters(IvfflatBuildState * buildstate) } /* Calculate centers */ - IvfflatBench("k-means", IvfflatKmeans(buildstate->index, buildstate->samples, buildstate->centers, buildstate->typeInfo)); + IvfflatBench("k-means", IvfflatKmeans(buildstate->index, buildstate->samples, buildstate->centers, buildstate->typeInfo, buildstate->memoryUsed)); /* Free samples before we allocate more memory */ VectorArrayFree(buildstate->samples); diff --git a/src/ivfflat.h b/src/ivfflat.h index c41e90c..ef39f3a 100644 --- a/src/ivfflat.h +++ b/src/ivfflat.h @@ -224,6 +224,7 @@ typedef struct IvfflatBuildState TupleTableSlot *slot; /* Memory */ + Size memoryUsed; MemoryContext tmpCtx; /* Parallel builds */ @@ -321,7 +322,7 @@ VectorArraySet(VectorArray arr, int offset, Pointer val) /* Methods */ VectorArray VectorArrayInit(int maxlen, int dimensions, Size itemsize); void VectorArrayFree(VectorArray arr); -void IvfflatKmeans(Relation index, VectorArray samples, VectorArray centers, const IvfflatTypeInfo * typeInfo); +void IvfflatKmeans(Relation index, VectorArray samples, VectorArray centers, const IvfflatTypeInfo * typeInfo, Size memoryUsed); FmgrInfo *IvfflatOptionalProcInfo(Relation index, uint16 procnum); Datum IvfflatNormValue(const IvfflatTypeInfo * typeInfo, Oid collation, Datum value); bool IvfflatCheckNorm(FmgrInfo *procinfo, Oid collation, Datum value); diff --git a/src/ivfkmeans.c b/src/ivfkmeans.c index 0107e13..ee647c1 100644 --- a/src/ivfkmeans.c +++ b/src/ivfkmeans.c @@ -244,7 +244,7 @@ ComputeNewCenters(VectorArray samples, float *agg, VectorArray newCenters, int * * https://www.aaai.org/Papers/ICML/2003/ICML03-022.pdf */ static void -ElkanKmeans(Relation index, VectorArray samples, VectorArray centers, const IvfflatTypeInfo * typeInfo) +ElkanKmeans(Relation index, VectorArray samples, VectorArray centers, const IvfflatTypeInfo * typeInfo, Size memoryUsed) { FmgrInfo *procinfo; FmgrInfo *normprocinfo; @@ -263,8 +263,6 @@ ElkanKmeans(Relation index, VectorArray samples, VectorArray centers, const Ivff float *newcdist; /* Calculate allocation sizes */ - Size samplesSize = VECTOR_ARRAY_SIZE(samples->maxlen, samples->itemsize); - Size centersSize = VECTOR_ARRAY_SIZE(centers->maxlen, centers->itemsize); Size newCentersSize = VECTOR_ARRAY_SIZE(numCenters, centers->itemsize); Size aggSize = sizeof(float) * (int64) numCenters * dimensions; Size centerCountsSize = sizeof(int) * numCenters; @@ -276,7 +274,7 @@ ElkanKmeans(Relation index, VectorArray samples, VectorArray centers, const Ivff Size newcdistSize = sizeof(float) * numCenters; /* Calculate total size */ - Size totalSize = samplesSize + centersSize + newCentersSize + aggSize + centerCountsSize + closestCentersSize + lowerBoundSize + upperBoundSize + sSize + halfcdistSize + newcdistSize; + Size totalSize = memoryUsed + newCentersSize + aggSize + centerCountsSize + closestCentersSize + lowerBoundSize + upperBoundSize + sSize + halfcdistSize + newcdistSize; /* Check memory requirements */ /* Add one to error message to ceil */ @@ -548,7 +546,7 @@ CheckCenters(Relation index, VectorArray centers, const IvfflatTypeInfo * typeIn * We use spherical k-means for inner product and cosine */ void -IvfflatKmeans(Relation index, VectorArray samples, VectorArray centers, const IvfflatTypeInfo * typeInfo) +IvfflatKmeans(Relation index, VectorArray samples, VectorArray centers, const IvfflatTypeInfo * typeInfo, Size memoryUsed) { MemoryContext kmeansCtx = AllocSetContextCreate(CurrentMemoryContext, "Ivfflat kmeans temporary context", @@ -558,7 +556,7 @@ IvfflatKmeans(Relation index, VectorArray samples, VectorArray centers, const Iv if (samples->length == 0) RandomCenters(index, centers, typeInfo); else - ElkanKmeans(index, samples, centers, typeInfo); + ElkanKmeans(index, samples, centers, typeInfo, memoryUsed); CheckCenters(index, centers, typeInfo);