mirror of
https://github.com/pgvector/pgvector.git
synced 2026-06-06 05:51:21 +08:00
Reduced support functions for IVFFlat - #527
This commit is contained in:
@@ -63,7 +63,7 @@ CREATE OPERATOR CLASS bit_hamming_ops
|
||||
OPERATOR 1 <~> (bit, bit) FOR ORDER BY float_ops,
|
||||
FUNCTION 1 hamming_distance(bit, bit),
|
||||
FUNCTION 3 hamming_distance(bit, bit),
|
||||
FUNCTION 6 ivfflat_bit_support(internal);
|
||||
FUNCTION 5 ivfflat_bit_support(internal);
|
||||
|
||||
CREATE OPERATOR CLASS bit_hamming_ops
|
||||
FOR TYPE bit USING hnsw AS
|
||||
@@ -330,7 +330,7 @@ CREATE OPERATOR CLASS halfvec_l2_ops
|
||||
OPERATOR 1 <-> (halfvec, halfvec) FOR ORDER BY float_ops,
|
||||
FUNCTION 1 halfvec_l2_squared_distance(halfvec, halfvec),
|
||||
FUNCTION 3 l2_distance(halfvec, halfvec),
|
||||
FUNCTION 6 ivfflat_halfvec_support(internal);
|
||||
FUNCTION 5 ivfflat_halfvec_support(internal);
|
||||
|
||||
CREATE OPERATOR CLASS halfvec_ip_ops
|
||||
FOR TYPE halfvec USING ivfflat AS
|
||||
@@ -338,8 +338,7 @@ CREATE OPERATOR CLASS halfvec_ip_ops
|
||||
FUNCTION 1 halfvec_negative_inner_product(halfvec, halfvec),
|
||||
FUNCTION 3 halfvec_spherical_distance(halfvec, halfvec),
|
||||
FUNCTION 4 l2_norm(halfvec),
|
||||
FUNCTION 5 l2_normalize(halfvec),
|
||||
FUNCTION 6 ivfflat_halfvec_support(internal);
|
||||
FUNCTION 5 ivfflat_halfvec_support(internal);
|
||||
|
||||
CREATE OPERATOR CLASS halfvec_cosine_ops
|
||||
FOR TYPE halfvec USING ivfflat AS
|
||||
@@ -348,8 +347,7 @@ CREATE OPERATOR CLASS halfvec_cosine_ops
|
||||
FUNCTION 2 l2_norm(halfvec),
|
||||
FUNCTION 3 halfvec_spherical_distance(halfvec, halfvec),
|
||||
FUNCTION 4 l2_norm(halfvec),
|
||||
FUNCTION 5 l2_normalize(halfvec),
|
||||
FUNCTION 6 ivfflat_halfvec_support(internal);
|
||||
FUNCTION 5 ivfflat_halfvec_support(internal);
|
||||
|
||||
CREATE OPERATOR CLASS halfvec_l2_ops
|
||||
FOR TYPE halfvec USING hnsw AS
|
||||
|
||||
@@ -358,7 +358,7 @@ CREATE OPERATOR CLASS bit_hamming_ops
|
||||
OPERATOR 1 <~> (bit, bit) FOR ORDER BY float_ops,
|
||||
FUNCTION 1 hamming_distance(bit, bit),
|
||||
FUNCTION 3 hamming_distance(bit, bit),
|
||||
FUNCTION 6 ivfflat_bit_support(internal);
|
||||
FUNCTION 5 ivfflat_bit_support(internal);
|
||||
|
||||
CREATE OPERATOR CLASS bit_hamming_ops
|
||||
FOR TYPE bit USING hnsw AS
|
||||
@@ -641,7 +641,7 @@ CREATE OPERATOR CLASS halfvec_l2_ops
|
||||
OPERATOR 1 <-> (halfvec, halfvec) FOR ORDER BY float_ops,
|
||||
FUNCTION 1 halfvec_l2_squared_distance(halfvec, halfvec),
|
||||
FUNCTION 3 l2_distance(halfvec, halfvec),
|
||||
FUNCTION 6 ivfflat_halfvec_support(internal);
|
||||
FUNCTION 5 ivfflat_halfvec_support(internal);
|
||||
|
||||
CREATE OPERATOR CLASS halfvec_ip_ops
|
||||
FOR TYPE halfvec USING ivfflat AS
|
||||
@@ -649,8 +649,7 @@ CREATE OPERATOR CLASS halfvec_ip_ops
|
||||
FUNCTION 1 halfvec_negative_inner_product(halfvec, halfvec),
|
||||
FUNCTION 3 halfvec_spherical_distance(halfvec, halfvec),
|
||||
FUNCTION 4 l2_norm(halfvec),
|
||||
FUNCTION 5 l2_normalize(halfvec),
|
||||
FUNCTION 6 ivfflat_halfvec_support(internal);
|
||||
FUNCTION 5 ivfflat_halfvec_support(internal);
|
||||
|
||||
CREATE OPERATOR CLASS halfvec_cosine_ops
|
||||
FOR TYPE halfvec USING ivfflat AS
|
||||
@@ -659,8 +658,7 @@ CREATE OPERATOR CLASS halfvec_cosine_ops
|
||||
FUNCTION 2 l2_norm(halfvec),
|
||||
FUNCTION 3 halfvec_spherical_distance(halfvec, halfvec),
|
||||
FUNCTION 4 l2_norm(halfvec),
|
||||
FUNCTION 5 l2_normalize(halfvec),
|
||||
FUNCTION 6 ivfflat_halfvec_support(internal);
|
||||
FUNCTION 5 ivfflat_halfvec_support(internal);
|
||||
|
||||
CREATE OPERATOR CLASS halfvec_l2_ops
|
||||
FOR TYPE halfvec USING hnsw AS
|
||||
|
||||
@@ -63,7 +63,7 @@ AddSample(Datum *values, IvfflatBuildState * buildstate)
|
||||
if (!IvfflatCheckNorm(buildstate->kmeansnormprocinfo, buildstate->collation, value))
|
||||
return;
|
||||
|
||||
value = IvfflatNormValue(buildstate->normalizeprocinfo, buildstate->collation, value);
|
||||
value = IvfflatNormValue(buildstate->typeInfo, buildstate->collation, value);
|
||||
}
|
||||
|
||||
if (samples->length < targsamples)
|
||||
@@ -161,7 +161,7 @@ AddTupleToSort(Relation index, ItemPointer tid, Datum *values, IvfflatBuildState
|
||||
if (!IvfflatCheckNorm(buildstate->normprocinfo, buildstate->collation, value))
|
||||
return;
|
||||
|
||||
value = IvfflatNormValue(buildstate->normalizeprocinfo, buildstate->collation, value);
|
||||
value = IvfflatNormValue(buildstate->typeInfo, buildstate->collation, value);
|
||||
}
|
||||
|
||||
/* Find the list that minimizes the distance */
|
||||
@@ -351,7 +351,6 @@ InitBuildState(IvfflatBuildState * buildstate, Relation heap, Relation index, In
|
||||
buildstate->procinfo = index_getprocinfo(index, 1, IVFFLAT_DISTANCE_PROC);
|
||||
buildstate->normprocinfo = IvfflatOptionalProcInfo(index, IVFFLAT_NORM_PROC);
|
||||
buildstate->kmeansnormprocinfo = IvfflatOptionalProcInfo(index, IVFFLAT_KMEANS_NORM_PROC);
|
||||
buildstate->normalizeprocinfo = IvfflatOptionalProcInfo(index, IVFFLAT_NORMALIZE_PROC);
|
||||
buildstate->collation = index->rd_indcollation[0];
|
||||
|
||||
/* Require more than one dimension for spherical k-means */
|
||||
|
||||
@@ -188,7 +188,7 @@ ivfflathandler(PG_FUNCTION_ARGS)
|
||||
IndexAmRoutine *amroutine = makeNode(IndexAmRoutine);
|
||||
|
||||
amroutine->amstrategies = 0;
|
||||
amroutine->amsupport = 6;
|
||||
amroutine->amsupport = 5;
|
||||
#if PG_VERSION_NUM >= 130000
|
||||
amroutine->amoptsprocnum = 0;
|
||||
#endif
|
||||
|
||||
@@ -28,8 +28,7 @@
|
||||
#define IVFFLAT_NORM_PROC 2
|
||||
#define IVFFLAT_KMEANS_DISTANCE_PROC 3
|
||||
#define IVFFLAT_KMEANS_NORM_PROC 4
|
||||
#define IVFFLAT_NORMALIZE_PROC 5
|
||||
#define IVFFLAT_TYPE_INFO_PROC 6
|
||||
#define IVFFLAT_TYPE_INFO_PROC 5
|
||||
|
||||
#define IVFFLAT_VERSION 1
|
||||
#define IVFFLAT_MAGIC_NUMBER 0x14FF1A7
|
||||
@@ -153,6 +152,7 @@ typedef struct IvfflatLeader
|
||||
typedef struct IvfflatTypeInfo
|
||||
{
|
||||
int maxDimensions;
|
||||
Datum (*normalize) (PG_FUNCTION_ARGS);
|
||||
void (*updateCenter) (Pointer v, int dimensions, float *x);
|
||||
void (*sumCenter) (Pointer v, float *x);
|
||||
} IvfflatTypeInfo;
|
||||
@@ -177,7 +177,6 @@ typedef struct IvfflatBuildState
|
||||
FmgrInfo *procinfo;
|
||||
FmgrInfo *normprocinfo;
|
||||
FmgrInfo *kmeansnormprocinfo;
|
||||
FmgrInfo *normalizeprocinfo;
|
||||
Oid collation;
|
||||
|
||||
/* Variables */
|
||||
@@ -245,6 +244,7 @@ typedef struct IvfflatScanList
|
||||
|
||||
typedef struct IvfflatScanOpaqueData
|
||||
{
|
||||
const IvfflatTypeInfo *typeInfo;
|
||||
int probes;
|
||||
int dimensions;
|
||||
bool first;
|
||||
@@ -258,7 +258,6 @@ typedef struct IvfflatScanOpaqueData
|
||||
/* Support functions */
|
||||
FmgrInfo *procinfo;
|
||||
FmgrInfo *normprocinfo;
|
||||
FmgrInfo *normalizeprocinfo;
|
||||
Oid collation;
|
||||
Datum (*distfunc) (FmgrInfo *flinfo, Oid collation, Datum arg1, Datum arg2);
|
||||
|
||||
@@ -279,7 +278,7 @@ VectorArray VectorArrayInit(int maxlen, int dimensions, Size itemsize);
|
||||
void VectorArrayFree(VectorArray arr);
|
||||
void IvfflatKmeans(Relation index, VectorArray samples, VectorArray centers, const IvfflatTypeInfo * typeInfo);
|
||||
FmgrInfo *IvfflatOptionalProcInfo(Relation index, uint16 procnum);
|
||||
Datum IvfflatNormValue(FmgrInfo *procinfo, Oid collation, Datum value);
|
||||
Datum IvfflatNormValue(const IvfflatTypeInfo * typeInfo, Oid collation, Datum value);
|
||||
bool IvfflatCheckNorm(FmgrInfo *procinfo, Oid collation, Datum value);
|
||||
int IvfflatGetLists(Relation index);
|
||||
void IvfflatGetMetaPageInfo(Relation index, int *lists, int *dimensions);
|
||||
|
||||
@@ -67,6 +67,7 @@ FindInsertPage(Relation index, Datum *values, BlockNumber *insertPage, ListInfo
|
||||
static void
|
||||
InsertTuple(Relation index, Datum *values, bool *isnull, ItemPointer heap_tid, Relation heapRel)
|
||||
{
|
||||
const IvfflatTypeInfo *typeInfo = IvfflatGetTypeInfo(index);
|
||||
IndexTuple itup;
|
||||
Datum value;
|
||||
FmgrInfo *normprocinfo;
|
||||
@@ -90,7 +91,7 @@ InsertTuple(Relation index, Datum *values, bool *isnull, ItemPointer heap_tid, R
|
||||
if (!IvfflatCheckNorm(normprocinfo, collation, value))
|
||||
return;
|
||||
|
||||
value = IvfflatNormValue(IvfflatOptionalProcInfo(index, IVFFLAT_NORMALIZE_PROC), collation, value);
|
||||
value = IvfflatNormValue(typeInfo, collation, value);
|
||||
}
|
||||
|
||||
/* Find the insert page - sets the page and list info */
|
||||
|
||||
@@ -92,7 +92,7 @@ InitCenters(Relation index, VectorArray samples, VectorArray centers, float *low
|
||||
* Norm centers
|
||||
*/
|
||||
static void
|
||||
NormCenters(FmgrInfo *normalizeprocinfo, Oid collation, VectorArray centers)
|
||||
NormCenters(const IvfflatTypeInfo * typeInfo, Oid collation, VectorArray centers)
|
||||
{
|
||||
MemoryContext normCtx = AllocSetContextCreate(CurrentMemoryContext,
|
||||
"Ivfflat norm temporary context",
|
||||
@@ -102,7 +102,7 @@ NormCenters(FmgrInfo *normalizeprocinfo, Oid collation, VectorArray centers)
|
||||
for (int j = 0; j < centers->length; j++)
|
||||
{
|
||||
Datum center = PointerGetDatum(VectorArrayGet(centers, j));
|
||||
Datum newCenter = IvfflatNormValue(normalizeprocinfo, collation, center);
|
||||
Datum newCenter = IvfflatNormValue(typeInfo, collation, center);
|
||||
Size size = VARSIZE_ANY(DatumGetPointer(newCenter));
|
||||
|
||||
if (size > centers->itemsize)
|
||||
@@ -123,9 +123,8 @@ static void
|
||||
RandomCenters(Relation index, VectorArray centers, const IvfflatTypeInfo * typeInfo)
|
||||
{
|
||||
int dimensions = centers->dim;
|
||||
Oid collation = index->rd_indcollation[0];
|
||||
FmgrInfo *normprocinfo = IvfflatOptionalProcInfo(index, IVFFLAT_KMEANS_NORM_PROC);
|
||||
FmgrInfo *normalizeprocinfo = IvfflatOptionalProcInfo(index, IVFFLAT_NORMALIZE_PROC);
|
||||
Oid collation = index->rd_indcollation[0];
|
||||
float *x = (float *) palloc(sizeof(float) * dimensions);
|
||||
|
||||
/* Fill with random data */
|
||||
@@ -142,7 +141,7 @@ RandomCenters(Relation index, VectorArray centers, const IvfflatTypeInfo * typeI
|
||||
}
|
||||
|
||||
if (normprocinfo != NULL)
|
||||
NormCenters(normalizeprocinfo, collation, centers);
|
||||
NormCenters(typeInfo, collation, centers);
|
||||
|
||||
pfree(x);
|
||||
}
|
||||
@@ -196,7 +195,7 @@ UpdateCenters(float *agg, VectorArray centers, const IvfflatTypeInfo * typeInfo)
|
||||
* Compute new centers
|
||||
*/
|
||||
static void
|
||||
ComputeNewCenters(VectorArray samples, float *agg, VectorArray newCenters, int *centerCounts, int *closestCenters, FmgrInfo *normprocinfo, FmgrInfo *normalizeprocinfo, Oid collation, const IvfflatTypeInfo * typeInfo)
|
||||
ComputeNewCenters(VectorArray samples, float *agg, VectorArray newCenters, int *centerCounts, int *closestCenters, FmgrInfo *normprocinfo, Oid collation, const IvfflatTypeInfo * typeInfo)
|
||||
{
|
||||
int dimensions = newCenters->dim;
|
||||
int numCenters = newCenters->length;
|
||||
@@ -251,7 +250,7 @@ ComputeNewCenters(VectorArray samples, float *agg, VectorArray newCenters, int *
|
||||
|
||||
/* Normalize if needed */
|
||||
if (normprocinfo != NULL)
|
||||
NormCenters(normalizeprocinfo, collation, newCenters);
|
||||
NormCenters(typeInfo, collation, newCenters);
|
||||
}
|
||||
|
||||
/*
|
||||
@@ -267,7 +266,6 @@ ElkanKmeans(Relation index, VectorArray samples, VectorArray centers, const Ivff
|
||||
{
|
||||
FmgrInfo *procinfo;
|
||||
FmgrInfo *normprocinfo;
|
||||
FmgrInfo *normalizeprocinfo;
|
||||
Oid collation;
|
||||
int dimensions = centers->dim;
|
||||
int numCenters = centers->maxlen;
|
||||
@@ -315,7 +313,6 @@ ElkanKmeans(Relation index, VectorArray samples, VectorArray centers, const Ivff
|
||||
/* Set support functions */
|
||||
procinfo = index_getprocinfo(index, 1, IVFFLAT_KMEANS_DISTANCE_PROC);
|
||||
normprocinfo = IvfflatOptionalProcInfo(index, IVFFLAT_KMEANS_NORM_PROC);
|
||||
normalizeprocinfo = IvfflatOptionalProcInfo(index, IVFFLAT_NORMALIZE_PROC);
|
||||
collation = index->rd_indcollation[0];
|
||||
|
||||
/* Use memory context */
|
||||
@@ -477,7 +474,7 @@ ElkanKmeans(Relation index, VectorArray samples, VectorArray centers, const Ivff
|
||||
}
|
||||
|
||||
/* Step 4: For each center c, let m(c) be mean of all points assigned */
|
||||
ComputeNewCenters(samples, agg, newCenters, centerCounts, closestCenters, normprocinfo, normalizeprocinfo, collation, typeInfo);
|
||||
ComputeNewCenters(samples, agg, newCenters, centerCounts, closestCenters, normprocinfo, collation, typeInfo);
|
||||
|
||||
/* Step 5 */
|
||||
for (int j = 0; j < numCenters; j++)
|
||||
|
||||
@@ -209,9 +209,9 @@ GetScanValue(IndexScanDesc scan)
|
||||
Assert(!VARATT_IS_COMPRESSED(DatumGetPointer(value)));
|
||||
Assert(!VARATT_IS_EXTENDED(DatumGetPointer(value)));
|
||||
|
||||
/* Check normprocinfo since normalizeprocinfo not set for vector */
|
||||
/* Normalize if needed */
|
||||
if (so->normprocinfo != NULL)
|
||||
value = IvfflatNormValue(so->normalizeprocinfo, so->collation, value);
|
||||
value = IvfflatNormValue(so->typeInfo, so->collation, value);
|
||||
}
|
||||
|
||||
return value;
|
||||
@@ -242,6 +242,7 @@ ivfflatbeginscan(Relation index, int nkeys, int norderbys)
|
||||
probes = lists;
|
||||
|
||||
so = (IvfflatScanOpaque) palloc(offsetof(IvfflatScanOpaqueData, lists) + probes * sizeof(IvfflatScanList));
|
||||
so->typeInfo = IvfflatGetTypeInfo(index);
|
||||
so->first = true;
|
||||
so->probes = probes;
|
||||
so->dimensions = dimensions;
|
||||
@@ -249,7 +250,6 @@ ivfflatbeginscan(Relation index, int nkeys, int norderbys)
|
||||
/* Set support functions */
|
||||
so->procinfo = index_getprocinfo(index, 1, IVFFLAT_DISTANCE_PROC);
|
||||
so->normprocinfo = IvfflatOptionalProcInfo(index, IVFFLAT_NORM_PROC);
|
||||
so->normalizeprocinfo = IvfflatOptionalProcInfo(index, IVFFLAT_NORMALIZE_PROC);
|
||||
so->collation = index->rd_indcollation[0];
|
||||
|
||||
/* Create tuple description for sorting */
|
||||
|
||||
@@ -68,12 +68,9 @@ IvfflatOptionalProcInfo(Relation index, uint16 procnum)
|
||||
* Normalize value
|
||||
*/
|
||||
Datum
|
||||
IvfflatNormValue(FmgrInfo *procinfo, Oid collation, Datum value)
|
||||
IvfflatNormValue(const IvfflatTypeInfo * typeInfo, Oid collation, Datum value)
|
||||
{
|
||||
if (procinfo == NULL)
|
||||
return DirectFunctionCall1(l2_normalize, value);
|
||||
|
||||
return FunctionCall1Coll(procinfo, collation, value);
|
||||
return DirectFunctionCall1Coll(typeInfo->normalize, collation, value);
|
||||
}
|
||||
|
||||
/*
|
||||
@@ -228,6 +225,10 @@ IvfflatUpdateList(Relation index, ListInfo listInfo,
|
||||
}
|
||||
}
|
||||
|
||||
PGDLLEXPORT Datum l2_normalize(PG_FUNCTION_ARGS);
|
||||
PGDLLEXPORT Datum halfvec_l2_normalize(PG_FUNCTION_ARGS);
|
||||
PGDLLEXPORT Datum sparsevec_l2_normalize(PG_FUNCTION_ARGS);
|
||||
|
||||
static void
|
||||
VectorUpdateCenter(Pointer v, int dimensions, float *x)
|
||||
{
|
||||
@@ -307,6 +308,7 @@ IvfflatGetTypeInfo(Relation index)
|
||||
{
|
||||
static const IvfflatTypeInfo typeInfo = {
|
||||
.maxDimensions = IVFFLAT_MAX_DIM,
|
||||
.normalize = l2_normalize,
|
||||
.updateCenter = VectorUpdateCenter,
|
||||
.sumCenter = VectorSumCenter
|
||||
};
|
||||
@@ -323,6 +325,7 @@ ivfflat_halfvec_support(PG_FUNCTION_ARGS)
|
||||
{
|
||||
static const IvfflatTypeInfo typeInfo = {
|
||||
.maxDimensions = IVFFLAT_MAX_DIM * 2,
|
||||
.normalize = halfvec_l2_normalize,
|
||||
.updateCenter = HalfvecUpdateCenter,
|
||||
.sumCenter = HalfvecSumCenter
|
||||
};
|
||||
@@ -336,6 +339,7 @@ ivfflat_bit_support(PG_FUNCTION_ARGS)
|
||||
{
|
||||
static const IvfflatTypeInfo typeInfo = {
|
||||
.maxDimensions = IVFFLAT_MAX_DIM * 32,
|
||||
.normalize = NULL,
|
||||
.updateCenter = BitUpdateCenter,
|
||||
.sumCenter = BitSumCenter
|
||||
};
|
||||
|
||||
@@ -21,6 +21,5 @@ typedef struct Vector
|
||||
Vector *InitVector(int dim);
|
||||
void PrintVector(char *msg, Vector * vector);
|
||||
int vector_cmp_internal(Vector * a, Vector * b);
|
||||
PGDLLEXPORT Datum l2_normalize(PG_FUNCTION_ARGS);
|
||||
|
||||
#endif
|
||||
|
||||
Reference in New Issue
Block a user