diff --git a/src/ivfbuild.c b/src/ivfbuild.c index f067cf7..b5e6da3 100644 --- a/src/ivfbuild.c +++ b/src/ivfbuild.c @@ -152,18 +152,7 @@ SampleRows(IvfflatBuildState * buildstate) /* Normalize if needed */ if (buildstate->kmeansnormprocinfo != NULL) - { - VectorArray samples = buildstate->samples; - - for (int i = 0; i < samples->length; i++) - { - Datum value = PointerGetDatum(VectorArrayGet(samples, i)); - Datum normValue = IvfflatNormValue(buildstate->typeInfo, buildstate->collation, value); - - VectorArraySet(samples, i, DatumGetPointer(normValue)); - pfree(DatumGetPointer(normValue)); - } - } + IvfflatNormVectors(buildstate->typeInfo, buildstate->collation, buildstate->samples, buildstate->tmpCtx); } /* diff --git a/src/ivfflat.h b/src/ivfflat.h index e0c06b7..87f6328 100644 --- a/src/ivfflat.h +++ b/src/ivfflat.h @@ -324,6 +324,7 @@ void IvfflatKmeans(Relation index, VectorArray samples, VectorArray centers, co FmgrInfo *IvfflatOptionalProcInfo(Relation index, uint16 procnum); Datum IvfflatNormValue(const IvfflatTypeInfo * typeInfo, Oid collation, Datum value); bool IvfflatCheckNorm(FmgrInfo *procinfo, Oid collation, Datum value); +void IvfflatNormVectors(const IvfflatTypeInfo * typeInfo, Oid collation, VectorArray arr, MemoryContext tmpCtx); int IvfflatGetLists(Relation index); void IvfflatGetMetaPageInfo(Relation index, int *lists, int *dimensions); void IvfflatUpdateList(Relation index, ListInfo listInfo, BlockNumber insertPage, BlockNumber originalInsertPage, BlockNumber startPage, ForkNumber forkNum); diff --git a/src/ivfkmeans.c b/src/ivfkmeans.c index f546f9e..0107e13 100644 --- a/src/ivfkmeans.c +++ b/src/ivfkmeans.c @@ -99,18 +99,8 @@ NormCenters(const IvfflatTypeInfo * typeInfo, Oid collation, VectorArray centers MemoryContext normCtx = AllocSetContextCreate(CurrentMemoryContext, "Ivfflat norm temporary context", ALLOCSET_DEFAULT_SIZES); - MemoryContext oldCtx = MemoryContextSwitchTo(normCtx); - for (int j = 0; j < centers->length; j++) - { - Datum center = PointerGetDatum(VectorArrayGet(centers, j)); - Datum newCenter = IvfflatNormValue(typeInfo, collation, center); - - VectorArraySet(centers, j, DatumGetPointer(newCenter)); - MemoryContextReset(normCtx); - } - - MemoryContextSwitchTo(oldCtx); + IvfflatNormVectors(typeInfo, collation, centers, normCtx); MemoryContextDelete(normCtx); } diff --git a/src/ivfutils.c b/src/ivfutils.c index f1586dd..b7ec566 100644 --- a/src/ivfutils.c +++ b/src/ivfutils.c @@ -7,6 +7,7 @@ #include "halfvec.h" #include "ivfflat.h" #include "storage/bufmgr.h" +#include "utils/memutils.h" #include "utils/relcache.h" #include "utils/varbit.h" #include "vector.h" @@ -88,6 +89,26 @@ IvfflatCheckNorm(FmgrInfo *procinfo, Oid collation, Datum value) return DatumGetFloat8(FunctionCall1Coll(procinfo, collation, value)) > 0; } +/* + * Normalize vectors + */ +void +IvfflatNormVectors(const IvfflatTypeInfo * typeInfo, Oid collation, VectorArray arr, MemoryContext tmpCtx) +{ + MemoryContext oldCtx = MemoryContextSwitchTo(tmpCtx); + + for (int i = 0; i < arr->length; i++) + { + Datum value = PointerGetDatum(VectorArrayGet(arr, i)); + Datum newValue = IvfflatNormValue(typeInfo, collation, value); + + VectorArraySet(arr, i, DatumGetPointer(newValue)); + MemoryContextReset(tmpCtx); + } + + MemoryContextSwitchTo(oldCtx); +} + /* * New buffer */