mirror of
https://github.com/pgvector/pgvector.git
synced 2026-07-03 03:00:56 +08:00
DRY normalize code for IVFFlat index builds
This commit is contained in:
@@ -152,18 +152,7 @@ SampleRows(IvfflatBuildState * buildstate)
|
||||
|
||||
/* Normalize if needed */
|
||||
if (buildstate->kmeansnormprocinfo != NULL)
|
||||
{
|
||||
VectorArray samples = buildstate->samples;
|
||||
|
||||
for (int i = 0; i < samples->length; i++)
|
||||
{
|
||||
Datum value = PointerGetDatum(VectorArrayGet(samples, i));
|
||||
Datum normValue = IvfflatNormValue(buildstate->typeInfo, buildstate->collation, value);
|
||||
|
||||
VectorArraySet(samples, i, DatumGetPointer(normValue));
|
||||
pfree(DatumGetPointer(normValue));
|
||||
}
|
||||
}
|
||||
IvfflatNormVectors(buildstate->typeInfo, buildstate->collation, buildstate->samples, buildstate->tmpCtx);
|
||||
}
|
||||
|
||||
/*
|
||||
|
||||
@@ -324,6 +324,7 @@ void IvfflatKmeans(Relation index, VectorArray samples, VectorArray centers, co
|
||||
FmgrInfo *IvfflatOptionalProcInfo(Relation index, uint16 procnum);
|
||||
Datum IvfflatNormValue(const IvfflatTypeInfo * typeInfo, Oid collation, Datum value);
|
||||
bool IvfflatCheckNorm(FmgrInfo *procinfo, Oid collation, Datum value);
|
||||
void IvfflatNormVectors(const IvfflatTypeInfo * typeInfo, Oid collation, VectorArray arr, MemoryContext tmpCtx);
|
||||
int IvfflatGetLists(Relation index);
|
||||
void IvfflatGetMetaPageInfo(Relation index, int *lists, int *dimensions);
|
||||
void IvfflatUpdateList(Relation index, ListInfo listInfo, BlockNumber insertPage, BlockNumber originalInsertPage, BlockNumber startPage, ForkNumber forkNum);
|
||||
|
||||
@@ -99,18 +99,8 @@ NormCenters(const IvfflatTypeInfo * typeInfo, Oid collation, VectorArray centers
|
||||
MemoryContext normCtx = AllocSetContextCreate(CurrentMemoryContext,
|
||||
"Ivfflat norm temporary context",
|
||||
ALLOCSET_DEFAULT_SIZES);
|
||||
MemoryContext oldCtx = MemoryContextSwitchTo(normCtx);
|
||||
|
||||
for (int j = 0; j < centers->length; j++)
|
||||
{
|
||||
Datum center = PointerGetDatum(VectorArrayGet(centers, j));
|
||||
Datum newCenter = IvfflatNormValue(typeInfo, collation, center);
|
||||
|
||||
VectorArraySet(centers, j, DatumGetPointer(newCenter));
|
||||
MemoryContextReset(normCtx);
|
||||
}
|
||||
|
||||
MemoryContextSwitchTo(oldCtx);
|
||||
IvfflatNormVectors(typeInfo, collation, centers, normCtx);
|
||||
MemoryContextDelete(normCtx);
|
||||
}
|
||||
|
||||
|
||||
@@ -7,6 +7,7 @@
|
||||
#include "halfvec.h"
|
||||
#include "ivfflat.h"
|
||||
#include "storage/bufmgr.h"
|
||||
#include "utils/memutils.h"
|
||||
#include "utils/relcache.h"
|
||||
#include "utils/varbit.h"
|
||||
#include "vector.h"
|
||||
@@ -88,6 +89,26 @@ IvfflatCheckNorm(FmgrInfo *procinfo, Oid collation, Datum value)
|
||||
return DatumGetFloat8(FunctionCall1Coll(procinfo, collation, value)) > 0;
|
||||
}
|
||||
|
||||
/*
|
||||
* Normalize vectors
|
||||
*/
|
||||
void
|
||||
IvfflatNormVectors(const IvfflatTypeInfo * typeInfo, Oid collation, VectorArray arr, MemoryContext tmpCtx)
|
||||
{
|
||||
MemoryContext oldCtx = MemoryContextSwitchTo(tmpCtx);
|
||||
|
||||
for (int i = 0; i < arr->length; i++)
|
||||
{
|
||||
Datum value = PointerGetDatum(VectorArrayGet(arr, i));
|
||||
Datum newValue = IvfflatNormValue(typeInfo, collation, value);
|
||||
|
||||
VectorArraySet(arr, i, DatumGetPointer(newValue));
|
||||
MemoryContextReset(tmpCtx);
|
||||
}
|
||||
|
||||
MemoryContextSwitchTo(oldCtx);
|
||||
}
|
||||
|
||||
/*
|
||||
* New buffer
|
||||
*/
|
||||
|
||||
Reference in New Issue
Block a user