Reduced support functions for IVFFlat - #527

This commit is contained in:
Andrew Kane
2024-04-25 13:56:20 -07:00
parent 1fdfff7349
commit 5dec500879
10 changed files with 36 additions and 41 deletions

View File

@@ -63,7 +63,7 @@ CREATE OPERATOR CLASS bit_hamming_ops
OPERATOR 1 <~> (bit, bit) FOR ORDER BY float_ops, OPERATOR 1 <~> (bit, bit) FOR ORDER BY float_ops,
FUNCTION 1 hamming_distance(bit, bit), FUNCTION 1 hamming_distance(bit, bit),
FUNCTION 3 hamming_distance(bit, bit), FUNCTION 3 hamming_distance(bit, bit),
FUNCTION 6 ivfflat_bit_support(internal); FUNCTION 5 ivfflat_bit_support(internal);
CREATE OPERATOR CLASS bit_hamming_ops CREATE OPERATOR CLASS bit_hamming_ops
FOR TYPE bit USING hnsw AS FOR TYPE bit USING hnsw AS
@@ -330,7 +330,7 @@ CREATE OPERATOR CLASS halfvec_l2_ops
OPERATOR 1 <-> (halfvec, halfvec) FOR ORDER BY float_ops, OPERATOR 1 <-> (halfvec, halfvec) FOR ORDER BY float_ops,
FUNCTION 1 halfvec_l2_squared_distance(halfvec, halfvec), FUNCTION 1 halfvec_l2_squared_distance(halfvec, halfvec),
FUNCTION 3 l2_distance(halfvec, halfvec), FUNCTION 3 l2_distance(halfvec, halfvec),
FUNCTION 6 ivfflat_halfvec_support(internal); FUNCTION 5 ivfflat_halfvec_support(internal);
CREATE OPERATOR CLASS halfvec_ip_ops CREATE OPERATOR CLASS halfvec_ip_ops
FOR TYPE halfvec USING ivfflat AS FOR TYPE halfvec USING ivfflat AS
@@ -338,8 +338,7 @@ CREATE OPERATOR CLASS halfvec_ip_ops
FUNCTION 1 halfvec_negative_inner_product(halfvec, halfvec), FUNCTION 1 halfvec_negative_inner_product(halfvec, halfvec),
FUNCTION 3 halfvec_spherical_distance(halfvec, halfvec), FUNCTION 3 halfvec_spherical_distance(halfvec, halfvec),
FUNCTION 4 l2_norm(halfvec), FUNCTION 4 l2_norm(halfvec),
FUNCTION 5 l2_normalize(halfvec), FUNCTION 5 ivfflat_halfvec_support(internal);
FUNCTION 6 ivfflat_halfvec_support(internal);
CREATE OPERATOR CLASS halfvec_cosine_ops CREATE OPERATOR CLASS halfvec_cosine_ops
FOR TYPE halfvec USING ivfflat AS FOR TYPE halfvec USING ivfflat AS
@@ -348,8 +347,7 @@ CREATE OPERATOR CLASS halfvec_cosine_ops
FUNCTION 2 l2_norm(halfvec), FUNCTION 2 l2_norm(halfvec),
FUNCTION 3 halfvec_spherical_distance(halfvec, halfvec), FUNCTION 3 halfvec_spherical_distance(halfvec, halfvec),
FUNCTION 4 l2_norm(halfvec), FUNCTION 4 l2_norm(halfvec),
FUNCTION 5 l2_normalize(halfvec), FUNCTION 5 ivfflat_halfvec_support(internal);
FUNCTION 6 ivfflat_halfvec_support(internal);
CREATE OPERATOR CLASS halfvec_l2_ops CREATE OPERATOR CLASS halfvec_l2_ops
FOR TYPE halfvec USING hnsw AS FOR TYPE halfvec USING hnsw AS

View File

@@ -358,7 +358,7 @@ CREATE OPERATOR CLASS bit_hamming_ops
OPERATOR 1 <~> (bit, bit) FOR ORDER BY float_ops, OPERATOR 1 <~> (bit, bit) FOR ORDER BY float_ops,
FUNCTION 1 hamming_distance(bit, bit), FUNCTION 1 hamming_distance(bit, bit),
FUNCTION 3 hamming_distance(bit, bit), FUNCTION 3 hamming_distance(bit, bit),
FUNCTION 6 ivfflat_bit_support(internal); FUNCTION 5 ivfflat_bit_support(internal);
CREATE OPERATOR CLASS bit_hamming_ops CREATE OPERATOR CLASS bit_hamming_ops
FOR TYPE bit USING hnsw AS FOR TYPE bit USING hnsw AS
@@ -641,7 +641,7 @@ CREATE OPERATOR CLASS halfvec_l2_ops
OPERATOR 1 <-> (halfvec, halfvec) FOR ORDER BY float_ops, OPERATOR 1 <-> (halfvec, halfvec) FOR ORDER BY float_ops,
FUNCTION 1 halfvec_l2_squared_distance(halfvec, halfvec), FUNCTION 1 halfvec_l2_squared_distance(halfvec, halfvec),
FUNCTION 3 l2_distance(halfvec, halfvec), FUNCTION 3 l2_distance(halfvec, halfvec),
FUNCTION 6 ivfflat_halfvec_support(internal); FUNCTION 5 ivfflat_halfvec_support(internal);
CREATE OPERATOR CLASS halfvec_ip_ops CREATE OPERATOR CLASS halfvec_ip_ops
FOR TYPE halfvec USING ivfflat AS FOR TYPE halfvec USING ivfflat AS
@@ -649,8 +649,7 @@ CREATE OPERATOR CLASS halfvec_ip_ops
FUNCTION 1 halfvec_negative_inner_product(halfvec, halfvec), FUNCTION 1 halfvec_negative_inner_product(halfvec, halfvec),
FUNCTION 3 halfvec_spherical_distance(halfvec, halfvec), FUNCTION 3 halfvec_spherical_distance(halfvec, halfvec),
FUNCTION 4 l2_norm(halfvec), FUNCTION 4 l2_norm(halfvec),
FUNCTION 5 l2_normalize(halfvec), FUNCTION 5 ivfflat_halfvec_support(internal);
FUNCTION 6 ivfflat_halfvec_support(internal);
CREATE OPERATOR CLASS halfvec_cosine_ops CREATE OPERATOR CLASS halfvec_cosine_ops
FOR TYPE halfvec USING ivfflat AS FOR TYPE halfvec USING ivfflat AS
@@ -659,8 +658,7 @@ CREATE OPERATOR CLASS halfvec_cosine_ops
FUNCTION 2 l2_norm(halfvec), FUNCTION 2 l2_norm(halfvec),
FUNCTION 3 halfvec_spherical_distance(halfvec, halfvec), FUNCTION 3 halfvec_spherical_distance(halfvec, halfvec),
FUNCTION 4 l2_norm(halfvec), FUNCTION 4 l2_norm(halfvec),
FUNCTION 5 l2_normalize(halfvec), FUNCTION 5 ivfflat_halfvec_support(internal);
FUNCTION 6 ivfflat_halfvec_support(internal);
CREATE OPERATOR CLASS halfvec_l2_ops CREATE OPERATOR CLASS halfvec_l2_ops
FOR TYPE halfvec USING hnsw AS FOR TYPE halfvec USING hnsw AS

View File

@@ -63,7 +63,7 @@ AddSample(Datum *values, IvfflatBuildState * buildstate)
if (!IvfflatCheckNorm(buildstate->kmeansnormprocinfo, buildstate->collation, value)) if (!IvfflatCheckNorm(buildstate->kmeansnormprocinfo, buildstate->collation, value))
return; return;
value = IvfflatNormValue(buildstate->normalizeprocinfo, buildstate->collation, value); value = IvfflatNormValue(buildstate->typeInfo, buildstate->collation, value);
} }
if (samples->length < targsamples) if (samples->length < targsamples)
@@ -161,7 +161,7 @@ AddTupleToSort(Relation index, ItemPointer tid, Datum *values, IvfflatBuildState
if (!IvfflatCheckNorm(buildstate->normprocinfo, buildstate->collation, value)) if (!IvfflatCheckNorm(buildstate->normprocinfo, buildstate->collation, value))
return; return;
value = IvfflatNormValue(buildstate->normalizeprocinfo, buildstate->collation, value); value = IvfflatNormValue(buildstate->typeInfo, buildstate->collation, value);
} }
/* Find the list that minimizes the distance */ /* Find the list that minimizes the distance */
@@ -351,7 +351,6 @@ InitBuildState(IvfflatBuildState * buildstate, Relation heap, Relation index, In
buildstate->procinfo = index_getprocinfo(index, 1, IVFFLAT_DISTANCE_PROC); buildstate->procinfo = index_getprocinfo(index, 1, IVFFLAT_DISTANCE_PROC);
buildstate->normprocinfo = IvfflatOptionalProcInfo(index, IVFFLAT_NORM_PROC); buildstate->normprocinfo = IvfflatOptionalProcInfo(index, IVFFLAT_NORM_PROC);
buildstate->kmeansnormprocinfo = IvfflatOptionalProcInfo(index, IVFFLAT_KMEANS_NORM_PROC); buildstate->kmeansnormprocinfo = IvfflatOptionalProcInfo(index, IVFFLAT_KMEANS_NORM_PROC);
buildstate->normalizeprocinfo = IvfflatOptionalProcInfo(index, IVFFLAT_NORMALIZE_PROC);
buildstate->collation = index->rd_indcollation[0]; buildstate->collation = index->rd_indcollation[0];
/* Require more than one dimension for spherical k-means */ /* Require more than one dimension for spherical k-means */

View File

@@ -188,7 +188,7 @@ ivfflathandler(PG_FUNCTION_ARGS)
IndexAmRoutine *amroutine = makeNode(IndexAmRoutine); IndexAmRoutine *amroutine = makeNode(IndexAmRoutine);
amroutine->amstrategies = 0; amroutine->amstrategies = 0;
amroutine->amsupport = 6; amroutine->amsupport = 5;
#if PG_VERSION_NUM >= 130000 #if PG_VERSION_NUM >= 130000
amroutine->amoptsprocnum = 0; amroutine->amoptsprocnum = 0;
#endif #endif

View File

@@ -28,8 +28,7 @@
#define IVFFLAT_NORM_PROC 2 #define IVFFLAT_NORM_PROC 2
#define IVFFLAT_KMEANS_DISTANCE_PROC 3 #define IVFFLAT_KMEANS_DISTANCE_PROC 3
#define IVFFLAT_KMEANS_NORM_PROC 4 #define IVFFLAT_KMEANS_NORM_PROC 4
#define IVFFLAT_NORMALIZE_PROC 5 #define IVFFLAT_TYPE_INFO_PROC 5
#define IVFFLAT_TYPE_INFO_PROC 6
#define IVFFLAT_VERSION 1 #define IVFFLAT_VERSION 1
#define IVFFLAT_MAGIC_NUMBER 0x14FF1A7 #define IVFFLAT_MAGIC_NUMBER 0x14FF1A7
@@ -153,6 +152,7 @@ typedef struct IvfflatLeader
typedef struct IvfflatTypeInfo typedef struct IvfflatTypeInfo
{ {
int maxDimensions; int maxDimensions;
Datum (*normalize) (PG_FUNCTION_ARGS);
void (*updateCenter) (Pointer v, int dimensions, float *x); void (*updateCenter) (Pointer v, int dimensions, float *x);
void (*sumCenter) (Pointer v, float *x); void (*sumCenter) (Pointer v, float *x);
} IvfflatTypeInfo; } IvfflatTypeInfo;
@@ -177,7 +177,6 @@ typedef struct IvfflatBuildState
FmgrInfo *procinfo; FmgrInfo *procinfo;
FmgrInfo *normprocinfo; FmgrInfo *normprocinfo;
FmgrInfo *kmeansnormprocinfo; FmgrInfo *kmeansnormprocinfo;
FmgrInfo *normalizeprocinfo;
Oid collation; Oid collation;
/* Variables */ /* Variables */
@@ -245,6 +244,7 @@ typedef struct IvfflatScanList
typedef struct IvfflatScanOpaqueData typedef struct IvfflatScanOpaqueData
{ {
const IvfflatTypeInfo *typeInfo;
int probes; int probes;
int dimensions; int dimensions;
bool first; bool first;
@@ -258,7 +258,6 @@ typedef struct IvfflatScanOpaqueData
/* Support functions */ /* Support functions */
FmgrInfo *procinfo; FmgrInfo *procinfo;
FmgrInfo *normprocinfo; FmgrInfo *normprocinfo;
FmgrInfo *normalizeprocinfo;
Oid collation; Oid collation;
Datum (*distfunc) (FmgrInfo *flinfo, Oid collation, Datum arg1, Datum arg2); Datum (*distfunc) (FmgrInfo *flinfo, Oid collation, Datum arg1, Datum arg2);
@@ -279,7 +278,7 @@ VectorArray VectorArrayInit(int maxlen, int dimensions, Size itemsize);
void VectorArrayFree(VectorArray arr); void VectorArrayFree(VectorArray arr);
void IvfflatKmeans(Relation index, VectorArray samples, VectorArray centers, const IvfflatTypeInfo * typeInfo); void IvfflatKmeans(Relation index, VectorArray samples, VectorArray centers, const IvfflatTypeInfo * typeInfo);
FmgrInfo *IvfflatOptionalProcInfo(Relation index, uint16 procnum); FmgrInfo *IvfflatOptionalProcInfo(Relation index, uint16 procnum);
Datum IvfflatNormValue(FmgrInfo *procinfo, Oid collation, Datum value); Datum IvfflatNormValue(const IvfflatTypeInfo * typeInfo, Oid collation, Datum value);
bool IvfflatCheckNorm(FmgrInfo *procinfo, Oid collation, Datum value); bool IvfflatCheckNorm(FmgrInfo *procinfo, Oid collation, Datum value);
int IvfflatGetLists(Relation index); int IvfflatGetLists(Relation index);
void IvfflatGetMetaPageInfo(Relation index, int *lists, int *dimensions); void IvfflatGetMetaPageInfo(Relation index, int *lists, int *dimensions);

View File

@@ -67,6 +67,7 @@ FindInsertPage(Relation index, Datum *values, BlockNumber *insertPage, ListInfo
static void static void
InsertTuple(Relation index, Datum *values, bool *isnull, ItemPointer heap_tid, Relation heapRel) InsertTuple(Relation index, Datum *values, bool *isnull, ItemPointer heap_tid, Relation heapRel)
{ {
const IvfflatTypeInfo *typeInfo = IvfflatGetTypeInfo(index);
IndexTuple itup; IndexTuple itup;
Datum value; Datum value;
FmgrInfo *normprocinfo; FmgrInfo *normprocinfo;
@@ -90,7 +91,7 @@ InsertTuple(Relation index, Datum *values, bool *isnull, ItemPointer heap_tid, R
if (!IvfflatCheckNorm(normprocinfo, collation, value)) if (!IvfflatCheckNorm(normprocinfo, collation, value))
return; return;
value = IvfflatNormValue(IvfflatOptionalProcInfo(index, IVFFLAT_NORMALIZE_PROC), collation, value); value = IvfflatNormValue(typeInfo, collation, value);
} }
/* Find the insert page - sets the page and list info */ /* Find the insert page - sets the page and list info */

View File

@@ -92,7 +92,7 @@ InitCenters(Relation index, VectorArray samples, VectorArray centers, float *low
* Norm centers * Norm centers
*/ */
static void static void
NormCenters(FmgrInfo *normalizeprocinfo, Oid collation, VectorArray centers) NormCenters(const IvfflatTypeInfo * typeInfo, Oid collation, VectorArray centers)
{ {
MemoryContext normCtx = AllocSetContextCreate(CurrentMemoryContext, MemoryContext normCtx = AllocSetContextCreate(CurrentMemoryContext,
"Ivfflat norm temporary context", "Ivfflat norm temporary context",
@@ -102,7 +102,7 @@ NormCenters(FmgrInfo *normalizeprocinfo, Oid collation, VectorArray centers)
for (int j = 0; j < centers->length; j++) for (int j = 0; j < centers->length; j++)
{ {
Datum center = PointerGetDatum(VectorArrayGet(centers, j)); Datum center = PointerGetDatum(VectorArrayGet(centers, j));
Datum newCenter = IvfflatNormValue(normalizeprocinfo, collation, center); Datum newCenter = IvfflatNormValue(typeInfo, collation, center);
Size size = VARSIZE_ANY(DatumGetPointer(newCenter)); Size size = VARSIZE_ANY(DatumGetPointer(newCenter));
if (size > centers->itemsize) if (size > centers->itemsize)
@@ -123,9 +123,8 @@ static void
RandomCenters(Relation index, VectorArray centers, const IvfflatTypeInfo * typeInfo) RandomCenters(Relation index, VectorArray centers, const IvfflatTypeInfo * typeInfo)
{ {
int dimensions = centers->dim; int dimensions = centers->dim;
Oid collation = index->rd_indcollation[0];
FmgrInfo *normprocinfo = IvfflatOptionalProcInfo(index, IVFFLAT_KMEANS_NORM_PROC); FmgrInfo *normprocinfo = IvfflatOptionalProcInfo(index, IVFFLAT_KMEANS_NORM_PROC);
FmgrInfo *normalizeprocinfo = IvfflatOptionalProcInfo(index, IVFFLAT_NORMALIZE_PROC); Oid collation = index->rd_indcollation[0];
float *x = (float *) palloc(sizeof(float) * dimensions); float *x = (float *) palloc(sizeof(float) * dimensions);
/* Fill with random data */ /* Fill with random data */
@@ -142,7 +141,7 @@ RandomCenters(Relation index, VectorArray centers, const IvfflatTypeInfo * typeI
} }
if (normprocinfo != NULL) if (normprocinfo != NULL)
NormCenters(normalizeprocinfo, collation, centers); NormCenters(typeInfo, collation, centers);
pfree(x); pfree(x);
} }
@@ -196,7 +195,7 @@ UpdateCenters(float *agg, VectorArray centers, const IvfflatTypeInfo * typeInfo)
* Compute new centers * Compute new centers
*/ */
static void static void
ComputeNewCenters(VectorArray samples, float *agg, VectorArray newCenters, int *centerCounts, int *closestCenters, FmgrInfo *normprocinfo, FmgrInfo *normalizeprocinfo, Oid collation, const IvfflatTypeInfo * typeInfo) ComputeNewCenters(VectorArray samples, float *agg, VectorArray newCenters, int *centerCounts, int *closestCenters, FmgrInfo *normprocinfo, Oid collation, const IvfflatTypeInfo * typeInfo)
{ {
int dimensions = newCenters->dim; int dimensions = newCenters->dim;
int numCenters = newCenters->length; int numCenters = newCenters->length;
@@ -251,7 +250,7 @@ ComputeNewCenters(VectorArray samples, float *agg, VectorArray newCenters, int *
/* Normalize if needed */ /* Normalize if needed */
if (normprocinfo != NULL) if (normprocinfo != NULL)
NormCenters(normalizeprocinfo, collation, newCenters); NormCenters(typeInfo, collation, newCenters);
} }
/* /*
@@ -267,7 +266,6 @@ ElkanKmeans(Relation index, VectorArray samples, VectorArray centers, const Ivff
{ {
FmgrInfo *procinfo; FmgrInfo *procinfo;
FmgrInfo *normprocinfo; FmgrInfo *normprocinfo;
FmgrInfo *normalizeprocinfo;
Oid collation; Oid collation;
int dimensions = centers->dim; int dimensions = centers->dim;
int numCenters = centers->maxlen; int numCenters = centers->maxlen;
@@ -315,7 +313,6 @@ ElkanKmeans(Relation index, VectorArray samples, VectorArray centers, const Ivff
/* Set support functions */ /* Set support functions */
procinfo = index_getprocinfo(index, 1, IVFFLAT_KMEANS_DISTANCE_PROC); procinfo = index_getprocinfo(index, 1, IVFFLAT_KMEANS_DISTANCE_PROC);
normprocinfo = IvfflatOptionalProcInfo(index, IVFFLAT_KMEANS_NORM_PROC); normprocinfo = IvfflatOptionalProcInfo(index, IVFFLAT_KMEANS_NORM_PROC);
normalizeprocinfo = IvfflatOptionalProcInfo(index, IVFFLAT_NORMALIZE_PROC);
collation = index->rd_indcollation[0]; collation = index->rd_indcollation[0];
/* Use memory context */ /* Use memory context */
@@ -477,7 +474,7 @@ ElkanKmeans(Relation index, VectorArray samples, VectorArray centers, const Ivff
} }
/* Step 4: For each center c, let m(c) be mean of all points assigned */ /* Step 4: For each center c, let m(c) be mean of all points assigned */
ComputeNewCenters(samples, agg, newCenters, centerCounts, closestCenters, normprocinfo, normalizeprocinfo, collation, typeInfo); ComputeNewCenters(samples, agg, newCenters, centerCounts, closestCenters, normprocinfo, collation, typeInfo);
/* Step 5 */ /* Step 5 */
for (int j = 0; j < numCenters; j++) for (int j = 0; j < numCenters; j++)

View File

@@ -209,9 +209,9 @@ GetScanValue(IndexScanDesc scan)
Assert(!VARATT_IS_COMPRESSED(DatumGetPointer(value))); Assert(!VARATT_IS_COMPRESSED(DatumGetPointer(value)));
Assert(!VARATT_IS_EXTENDED(DatumGetPointer(value))); Assert(!VARATT_IS_EXTENDED(DatumGetPointer(value)));
/* Check normprocinfo since normalizeprocinfo not set for vector */ /* Normalize if needed */
if (so->normprocinfo != NULL) if (so->normprocinfo != NULL)
value = IvfflatNormValue(so->normalizeprocinfo, so->collation, value); value = IvfflatNormValue(so->typeInfo, so->collation, value);
} }
return value; return value;
@@ -242,6 +242,7 @@ ivfflatbeginscan(Relation index, int nkeys, int norderbys)
probes = lists; probes = lists;
so = (IvfflatScanOpaque) palloc(offsetof(IvfflatScanOpaqueData, lists) + probes * sizeof(IvfflatScanList)); so = (IvfflatScanOpaque) palloc(offsetof(IvfflatScanOpaqueData, lists) + probes * sizeof(IvfflatScanList));
so->typeInfo = IvfflatGetTypeInfo(index);
so->first = true; so->first = true;
so->probes = probes; so->probes = probes;
so->dimensions = dimensions; so->dimensions = dimensions;
@@ -249,7 +250,6 @@ ivfflatbeginscan(Relation index, int nkeys, int norderbys)
/* Set support functions */ /* Set support functions */
so->procinfo = index_getprocinfo(index, 1, IVFFLAT_DISTANCE_PROC); so->procinfo = index_getprocinfo(index, 1, IVFFLAT_DISTANCE_PROC);
so->normprocinfo = IvfflatOptionalProcInfo(index, IVFFLAT_NORM_PROC); so->normprocinfo = IvfflatOptionalProcInfo(index, IVFFLAT_NORM_PROC);
so->normalizeprocinfo = IvfflatOptionalProcInfo(index, IVFFLAT_NORMALIZE_PROC);
so->collation = index->rd_indcollation[0]; so->collation = index->rd_indcollation[0];
/* Create tuple description for sorting */ /* Create tuple description for sorting */

View File

@@ -68,12 +68,9 @@ IvfflatOptionalProcInfo(Relation index, uint16 procnum)
* Normalize value * Normalize value
*/ */
Datum Datum
IvfflatNormValue(FmgrInfo *procinfo, Oid collation, Datum value) IvfflatNormValue(const IvfflatTypeInfo * typeInfo, Oid collation, Datum value)
{ {
if (procinfo == NULL) return DirectFunctionCall1Coll(typeInfo->normalize, collation, value);
return DirectFunctionCall1(l2_normalize, value);
return FunctionCall1Coll(procinfo, collation, value);
} }
/* /*
@@ -228,6 +225,10 @@ IvfflatUpdateList(Relation index, ListInfo listInfo,
} }
} }
PGDLLEXPORT Datum l2_normalize(PG_FUNCTION_ARGS);
PGDLLEXPORT Datum halfvec_l2_normalize(PG_FUNCTION_ARGS);
PGDLLEXPORT Datum sparsevec_l2_normalize(PG_FUNCTION_ARGS);
static void static void
VectorUpdateCenter(Pointer v, int dimensions, float *x) VectorUpdateCenter(Pointer v, int dimensions, float *x)
{ {
@@ -307,6 +308,7 @@ IvfflatGetTypeInfo(Relation index)
{ {
static const IvfflatTypeInfo typeInfo = { static const IvfflatTypeInfo typeInfo = {
.maxDimensions = IVFFLAT_MAX_DIM, .maxDimensions = IVFFLAT_MAX_DIM,
.normalize = l2_normalize,
.updateCenter = VectorUpdateCenter, .updateCenter = VectorUpdateCenter,
.sumCenter = VectorSumCenter .sumCenter = VectorSumCenter
}; };
@@ -323,6 +325,7 @@ ivfflat_halfvec_support(PG_FUNCTION_ARGS)
{ {
static const IvfflatTypeInfo typeInfo = { static const IvfflatTypeInfo typeInfo = {
.maxDimensions = IVFFLAT_MAX_DIM * 2, .maxDimensions = IVFFLAT_MAX_DIM * 2,
.normalize = halfvec_l2_normalize,
.updateCenter = HalfvecUpdateCenter, .updateCenter = HalfvecUpdateCenter,
.sumCenter = HalfvecSumCenter .sumCenter = HalfvecSumCenter
}; };
@@ -336,6 +339,7 @@ ivfflat_bit_support(PG_FUNCTION_ARGS)
{ {
static const IvfflatTypeInfo typeInfo = { static const IvfflatTypeInfo typeInfo = {
.maxDimensions = IVFFLAT_MAX_DIM * 32, .maxDimensions = IVFFLAT_MAX_DIM * 32,
.normalize = NULL,
.updateCenter = BitUpdateCenter, .updateCenter = BitUpdateCenter,
.sumCenter = BitSumCenter .sumCenter = BitSumCenter
}; };

View File

@@ -21,6 +21,5 @@ typedef struct Vector
Vector *InitVector(int dim); Vector *InitVector(int dim);
void PrintVector(char *msg, Vector * vector); void PrintVector(char *msg, Vector * vector);
int vector_cmp_internal(Vector * a, Vector * b); int vector_cmp_internal(Vector * a, Vector * b);
PGDLLEXPORT Datum l2_normalize(PG_FUNCTION_ARGS);
#endif #endif