mirror of
https://github.com/pgvector/pgvector.git
synced 2026-06-06 05:51:21 +08:00
Switched to static const for IVFFlat type info
This commit is contained in:
@@ -325,18 +325,14 @@ InsertTuples(Relation index, IvfflatBuildState * buildstate, ForkNumber forkNum)
|
||||
static void
|
||||
InitBuildState(IvfflatBuildState * buildstate, Relation heap, Relation index, IndexInfo *indexInfo)
|
||||
{
|
||||
IvfflatTypeInfo *typeInfo = &buildstate->typeInfo;
|
||||
|
||||
buildstate->heap = heap;
|
||||
buildstate->index = index;
|
||||
buildstate->indexInfo = indexInfo;
|
||||
buildstate->typeInfo = IvfflatGetTypeInfo(index);
|
||||
|
||||
buildstate->lists = IvfflatGetLists(index);
|
||||
buildstate->dimensions = TupleDescAttr(index->rd_att, 0)->atttypmod;
|
||||
|
||||
typeInfo->dimensions = buildstate->dimensions;
|
||||
IvfflatGetTypeInfo(typeInfo, index);
|
||||
|
||||
/* Disallow varbit since require fixed dimensions */
|
||||
if (TupleDescAttr(index->rd_att, 0)->atttypid == VARBITOID)
|
||||
elog(ERROR, "type not supported for ivfflat index");
|
||||
@@ -345,8 +341,8 @@ InitBuildState(IvfflatBuildState * buildstate, Relation heap, Relation index, In
|
||||
if (buildstate->dimensions < 0)
|
||||
elog(ERROR, "column does not have dimensions");
|
||||
|
||||
if (buildstate->dimensions > typeInfo->maxDimensions)
|
||||
elog(ERROR, "column cannot have more than %d dimensions for ivfflat index", typeInfo->maxDimensions);
|
||||
if (buildstate->dimensions > buildstate->typeInfo->maxDimensions)
|
||||
elog(ERROR, "column cannot have more than %d dimensions for ivfflat index", buildstate->typeInfo->maxDimensions);
|
||||
|
||||
buildstate->reltuples = 0;
|
||||
buildstate->indtuples = 0;
|
||||
@@ -370,7 +366,8 @@ InitBuildState(IvfflatBuildState * buildstate, Relation heap, Relation index, In
|
||||
|
||||
buildstate->slot = MakeSingleTupleTableSlot(buildstate->tupdesc, &TTSOpsVirtual);
|
||||
|
||||
buildstate->centers = VectorArrayInit(buildstate->lists, buildstate->dimensions, typeInfo->itemsize);
|
||||
/* TODO Fix item size */
|
||||
buildstate->centers = VectorArrayInit(buildstate->lists, buildstate->dimensions, VECTOR_SIZE(buildstate->dimensions));
|
||||
buildstate->listInfo = palloc(sizeof(ListInfo) * buildstate->lists);
|
||||
|
||||
buildstate->tmpCtx = AllocSetContextCreate(CurrentMemoryContext,
|
||||
@@ -440,7 +437,7 @@ ComputeCenters(IvfflatBuildState * buildstate)
|
||||
}
|
||||
|
||||
/* Calculate centers */
|
||||
IvfflatBench("k-means", IvfflatKmeans(buildstate->index, buildstate->samples, buildstate->centers, &buildstate->typeInfo));
|
||||
IvfflatBench("k-means", IvfflatKmeans(buildstate->index, buildstate->samples, buildstate->centers, buildstate->typeInfo));
|
||||
|
||||
/* Free samples before we allocate more memory */
|
||||
VectorArrayFree(buildstate->samples);
|
||||
|
||||
@@ -165,7 +165,7 @@ typedef struct IvfflatBuildState
|
||||
Relation heap;
|
||||
Relation index;
|
||||
IndexInfo *indexInfo;
|
||||
IvfflatTypeInfo typeInfo;
|
||||
const IvfflatTypeInfo *typeInfo;
|
||||
|
||||
/* Settings */
|
||||
int dimensions;
|
||||
@@ -279,7 +279,7 @@ typedef IvfflatScanOpaqueData * IvfflatScanOpaque;
|
||||
/* Methods */
|
||||
VectorArray VectorArrayInit(int maxlen, int dimensions, Size itemsize);
|
||||
void VectorArrayFree(VectorArray arr);
|
||||
void IvfflatKmeans(Relation index, VectorArray samples, VectorArray centers, IvfflatTypeInfo * typeInfo);
|
||||
void IvfflatKmeans(Relation index, VectorArray samples, VectorArray centers, const IvfflatTypeInfo * typeInfo);
|
||||
FmgrInfo *IvfflatOptionalProcInfo(Relation index, uint16 procnum);
|
||||
Datum IvfflatNormValue(FmgrInfo *procinfo, Oid collation, Datum value);
|
||||
bool IvfflatCheckNorm(FmgrInfo *procinfo, Oid collation, Datum value);
|
||||
@@ -292,7 +292,7 @@ Buffer IvfflatNewBuffer(Relation index, ForkNumber forkNum);
|
||||
void IvfflatInitPage(Buffer buf, Page page);
|
||||
void IvfflatInitRegisterPage(Relation index, Buffer *buf, Page *page, GenericXLogState **state);
|
||||
void IvfflatInit(void);
|
||||
void IvfflatGetTypeInfo(IvfflatTypeInfo * typeInfo, Relation index);
|
||||
const IvfflatTypeInfo *IvfflatGetTypeInfo(Relation index);
|
||||
PGDLLEXPORT void IvfflatParallelBuildMain(dsm_segment *seg, shm_toc *toc);
|
||||
|
||||
/* Index access methods */
|
||||
|
||||
@@ -120,7 +120,7 @@ NormCenters(FmgrInfo *normalizeprocinfo, Oid collation, VectorArray centers)
|
||||
* Quick approach if we have no data
|
||||
*/
|
||||
static void
|
||||
RandomCenters(Relation index, VectorArray centers, IvfflatTypeInfo * typeInfo)
|
||||
RandomCenters(Relation index, VectorArray centers, const IvfflatTypeInfo * typeInfo)
|
||||
{
|
||||
int dimensions = centers->dim;
|
||||
Oid collation = index->rd_indcollation[0];
|
||||
@@ -168,7 +168,7 @@ ShowMemoryUsage(MemoryContext context, Size estimatedSize)
|
||||
* Sum centers
|
||||
*/
|
||||
static void
|
||||
SumCenters(VectorArray samples, float *agg, int *closestCenters, IvfflatTypeInfo * typeInfo)
|
||||
SumCenters(VectorArray samples, float *agg, int *closestCenters, const IvfflatTypeInfo * typeInfo)
|
||||
{
|
||||
for (int j = 0; j < samples->length; j++)
|
||||
{
|
||||
@@ -182,7 +182,7 @@ SumCenters(VectorArray samples, float *agg, int *closestCenters, IvfflatTypeInfo
|
||||
* Update centers
|
||||
*/
|
||||
static void
|
||||
UpdateCenters(float *agg, VectorArray centers, IvfflatTypeInfo * typeInfo)
|
||||
UpdateCenters(float *agg, VectorArray centers, const IvfflatTypeInfo * typeInfo)
|
||||
{
|
||||
for (int j = 0; j < centers->length; j++)
|
||||
{
|
||||
@@ -196,7 +196,7 @@ UpdateCenters(float *agg, VectorArray centers, IvfflatTypeInfo * typeInfo)
|
||||
* Compute new centers
|
||||
*/
|
||||
static void
|
||||
ComputeNewCenters(VectorArray samples, float *agg, VectorArray newCenters, int *centerCounts, int *closestCenters, FmgrInfo *normprocinfo, FmgrInfo *normalizeprocinfo, Oid collation, IvfflatTypeInfo * typeInfo)
|
||||
ComputeNewCenters(VectorArray samples, float *agg, VectorArray newCenters, int *centerCounts, int *closestCenters, FmgrInfo *normprocinfo, FmgrInfo *normalizeprocinfo, Oid collation, const IvfflatTypeInfo * typeInfo)
|
||||
{
|
||||
int dimensions = newCenters->dim;
|
||||
int numCenters = newCenters->length;
|
||||
@@ -263,7 +263,7 @@ ComputeNewCenters(VectorArray samples, float *agg, VectorArray newCenters, int *
|
||||
* https://www.aaai.org/Papers/ICML/2003/ICML03-022.pdf
|
||||
*/
|
||||
static void
|
||||
ElkanKmeans(Relation index, VectorArray samples, VectorArray centers, IvfflatTypeInfo * typeInfo)
|
||||
ElkanKmeans(Relation index, VectorArray samples, VectorArray centers, const IvfflatTypeInfo * typeInfo)
|
||||
{
|
||||
FmgrInfo *procinfo;
|
||||
FmgrInfo *normprocinfo;
|
||||
@@ -517,7 +517,7 @@ ElkanKmeans(Relation index, VectorArray samples, VectorArray centers, IvfflatTyp
|
||||
* Detect issues with centers
|
||||
*/
|
||||
static void
|
||||
CheckCenters(Relation index, VectorArray centers, IvfflatTypeInfo * typeInfo)
|
||||
CheckCenters(Relation index, VectorArray centers, const IvfflatTypeInfo * typeInfo)
|
||||
{
|
||||
FmgrInfo *normprocinfo;
|
||||
float *scratch = palloc(sizeof(float) * centers->dim);
|
||||
@@ -568,7 +568,7 @@ CheckCenters(Relation index, VectorArray centers, IvfflatTypeInfo * typeInfo)
|
||||
* We use spherical k-means for inner product and cosine
|
||||
*/
|
||||
void
|
||||
IvfflatKmeans(Relation index, VectorArray samples, VectorArray centers, IvfflatTypeInfo * typeInfo)
|
||||
IvfflatKmeans(Relation index, VectorArray samples, VectorArray centers, const IvfflatTypeInfo * typeInfo)
|
||||
{
|
||||
if (samples->length == 0)
|
||||
RandomCenters(index, centers, typeInfo);
|
||||
|
||||
@@ -298,46 +298,47 @@ BitSumCenter(Pointer v, float *x)
|
||||
/*
|
||||
* Get type info
|
||||
*/
|
||||
void
|
||||
IvfflatGetTypeInfo(IvfflatTypeInfo * typeInfo, Relation index)
|
||||
const IvfflatTypeInfo *
|
||||
IvfflatGetTypeInfo(Relation index)
|
||||
{
|
||||
FmgrInfo *procinfo = IvfflatOptionalProcInfo(index, IVFFLAT_TYPE_INFO_PROC);
|
||||
|
||||
if (procinfo == NULL)
|
||||
{
|
||||
typeInfo->maxDimensions = IVFFLAT_MAX_DIM;
|
||||
typeInfo->itemsize = VECTOR_SIZE(typeInfo->dimensions);
|
||||
typeInfo->updateCenter = VectorUpdateCenter;
|
||||
typeInfo->sumCenter = VectorSumCenter;
|
||||
static const IvfflatTypeInfo typeInfo = {
|
||||
.maxDimensions = IVFFLAT_MAX_DIM,
|
||||
.updateCenter = VectorUpdateCenter,
|
||||
.sumCenter = VectorSumCenter
|
||||
};
|
||||
|
||||
return (&typeInfo);
|
||||
}
|
||||
else
|
||||
FunctionCall1(procinfo, PointerGetDatum(typeInfo));
|
||||
return (const IvfflatTypeInfo *) DatumGetPointer(FunctionCall0Coll(procinfo, InvalidOid));
|
||||
}
|
||||
|
||||
PGDLLEXPORT PG_FUNCTION_INFO_V1(ivfflat_halfvec_support);
|
||||
Datum
|
||||
ivfflat_halfvec_support(PG_FUNCTION_ARGS)
|
||||
{
|
||||
IvfflatTypeInfo *typeInfo = (IvfflatTypeInfo *) PG_GETARG_POINTER(0);
|
||||
static const IvfflatTypeInfo typeInfo = {
|
||||
.maxDimensions = IVFFLAT_MAX_DIM * 2,
|
||||
.updateCenter = HalfvecUpdateCenter,
|
||||
.sumCenter = HalfvecSumCenter
|
||||
};
|
||||
|
||||
typeInfo->maxDimensions = IVFFLAT_MAX_DIM * 2;
|
||||
typeInfo->itemsize = HALFVEC_SIZE(typeInfo->dimensions);
|
||||
typeInfo->updateCenter = HalfvecUpdateCenter;
|
||||
typeInfo->sumCenter = HalfvecSumCenter;
|
||||
|
||||
PG_RETURN_VOID();
|
||||
PG_RETURN_POINTER(&typeInfo);
|
||||
};
|
||||
|
||||
PGDLLEXPORT PG_FUNCTION_INFO_V1(ivfflat_bit_support);
|
||||
Datum
|
||||
ivfflat_bit_support(PG_FUNCTION_ARGS)
|
||||
{
|
||||
IvfflatTypeInfo *typeInfo = (IvfflatTypeInfo *) PG_GETARG_POINTER(0);
|
||||
static const IvfflatTypeInfo typeInfo = {
|
||||
.maxDimensions = IVFFLAT_MAX_DIM * 32,
|
||||
.updateCenter = BitUpdateCenter,
|
||||
.sumCenter = BitSumCenter
|
||||
};
|
||||
|
||||
typeInfo->maxDimensions = IVFFLAT_MAX_DIM * 32;
|
||||
typeInfo->itemsize = VARBITTOTALLEN(typeInfo->dimensions);
|
||||
typeInfo->updateCenter = BitUpdateCenter;
|
||||
typeInfo->sumCenter = BitSumCenter;
|
||||
|
||||
PG_RETURN_VOID();
|
||||
PG_RETURN_POINTER(&typeInfo);
|
||||
};
|
||||
|
||||
Reference in New Issue
Block a user