diff --git a/src/ivfbuild.c b/src/ivfbuild.c index 8a7b52d..cfff09d 100644 --- a/src/ivfbuild.c +++ b/src/ivfbuild.c @@ -365,8 +365,7 @@ InitBuildState(IvfflatBuildState * buildstate, Relation heap, Relation index, In buildstate->slot = MakeSingleTupleTableSlot(buildstate->tupdesc, &TTSOpsVirtual); - /* TODO Fix item size */ - buildstate->centers = VectorArrayInit(buildstate->lists, buildstate->dimensions, VECTOR_SIZE(buildstate->dimensions)); + buildstate->centers = VectorArrayInit(buildstate->lists, buildstate->dimensions, buildstate->typeInfo->itemSize(buildstate->dimensions)); buildstate->listInfo = palloc(sizeof(ListInfo) * buildstate->lists); buildstate->tmpCtx = AllocSetContextCreate(CurrentMemoryContext, diff --git a/src/ivfflat.h b/src/ivfflat.h index dd96f57..d3fa219 100644 --- a/src/ivfflat.h +++ b/src/ivfflat.h @@ -153,6 +153,7 @@ typedef struct IvfflatTypeInfo { int maxDimensions; Datum (*normalize) (PG_FUNCTION_ARGS); + Size (*itemSize) (int dimensions); void (*updateCenter) (Pointer v, int dimensions, float *x); void (*sumCenter) (Pointer v, float *x); } IvfflatTypeInfo; diff --git a/src/ivfutils.c b/src/ivfutils.c index 6ecf85a..74a4159 100644 --- a/src/ivfutils.c +++ b/src/ivfutils.c @@ -229,6 +229,24 @@ PGDLLEXPORT Datum l2_normalize(PG_FUNCTION_ARGS); PGDLLEXPORT Datum halfvec_l2_normalize(PG_FUNCTION_ARGS); PGDLLEXPORT Datum sparsevec_l2_normalize(PG_FUNCTION_ARGS); +static Size +VectorItemSize(int dimensions) +{ + return VECTOR_SIZE(dimensions); +} + +static Size +HalfvecItemSize(int dimensions) +{ + return HALFVEC_SIZE(dimensions); +} + +static Size +BitItemSize(int dimensions) +{ + return VARBITTOTALLEN(dimensions); +} + static void VectorUpdateCenter(Pointer v, int dimensions, float *x) { @@ -309,6 +327,7 @@ IvfflatGetTypeInfo(Relation index) static const IvfflatTypeInfo typeInfo = { .maxDimensions = IVFFLAT_MAX_DIM, .normalize = l2_normalize, + .itemSize = VectorItemSize, .updateCenter = VectorUpdateCenter, .sumCenter = VectorSumCenter }; @@ -326,6 +345,7 @@ ivfflat_halfvec_support(PG_FUNCTION_ARGS) static const IvfflatTypeInfo typeInfo = { .maxDimensions = IVFFLAT_MAX_DIM * 2, .normalize = halfvec_l2_normalize, + .itemSize = HalfvecItemSize, .updateCenter = HalfvecUpdateCenter, .sumCenter = HalfvecSumCenter }; @@ -340,6 +360,7 @@ ivfflat_bit_support(PG_FUNCTION_ARGS) static const IvfflatTypeInfo typeInfo = { .maxDimensions = IVFFLAT_MAX_DIM * 32, .normalize = NULL, + .itemSize = BitItemSize, .updateCenter = BitUpdateCenter, .sumCenter = BitSumCenter };