diff --git a/sql/vector--0.6.2--0.7.0.sql b/sql/vector--0.6.2--0.7.0.sql
index e054031..84c7f3a 100644
--- a/sql/vector--0.6.2--0.7.0.sql
+++ b/sql/vector--0.6.2--0.7.0.sql
@@ -23,28 +23,40 @@ CREATE OPERATOR || (
 );
 
 CREATE FUNCTION ivfflat_bit_max_dims(internal) RETURNS internal
-	AS 'MODULE_PATHNAME' LANGUAGE C IMMUTABLE STRICT PARALLEL SAFE;
+	AS 'MODULE_PATHNAME' LANGUAGE C;
 
 CREATE FUNCTION ivfflat_halfvec_max_dims(internal) RETURNS internal
-	AS 'MODULE_PATHNAME' LANGUAGE C IMMUTABLE STRICT PARALLEL SAFE;
+	AS 'MODULE_PATHNAME' LANGUAGE C;
 
-CREATE FUNCTION ivfflat_bit_support(internal) RETURNS internal
-	AS 'MODULE_PATHNAME' LANGUAGE C IMMUTABLE STRICT PARALLEL SAFE;
+CREATE FUNCTION ivfflat_vector_update_center(internal, internal, internal) RETURNS internal
+	AS 'MODULE_PATHNAME' LANGUAGE C;
 
-CREATE FUNCTION ivfflat_halfvec_support(internal) RETURNS internal
-	AS 'MODULE_PATHNAME' LANGUAGE C IMMUTABLE STRICT PARALLEL SAFE;
+CREATE FUNCTION ivfflat_bit_update_center(internal, internal, internal) RETURNS internal
+	AS 'MODULE_PATHNAME' LANGUAGE C;
+
+CREATE FUNCTION ivfflat_halfvec_update_center(internal, internal, internal) RETURNS internal
+	AS 'MODULE_PATHNAME' LANGUAGE C;
+
+CREATE FUNCTION ivfflat_vector_sum_center(internal, internal) RETURNS internal
+	AS 'MODULE_PATHNAME' LANGUAGE C;
+
+CREATE FUNCTION ivfflat_bit_sum_center(internal, internal) RETURNS internal
+	AS 'MODULE_PATHNAME' LANGUAGE C;
+
+CREATE FUNCTION ivfflat_halfvec_sum_center(internal, internal) RETURNS internal
+	AS 'MODULE_PATHNAME' LANGUAGE C;
 
 CREATE FUNCTION hnsw_bit_max_dims(internal) RETURNS internal
-	AS 'MODULE_PATHNAME' LANGUAGE C IMMUTABLE STRICT PARALLEL SAFE;
+	AS 'MODULE_PATHNAME' LANGUAGE C;
 
 CREATE FUNCTION hnsw_halfvec_max_dims(internal) RETURNS internal
-	AS 'MODULE_PATHNAME' LANGUAGE C IMMUTABLE STRICT PARALLEL SAFE;
+	AS 'MODULE_PATHNAME' LANGUAGE C;
 
 CREATE FUNCTION hnsw_sparsevec_max_dims(internal) RETURNS internal
-	AS 'MODULE_PATHNAME' LANGUAGE C IMMUTABLE STRICT PARALLEL SAFE;
+	AS 'MODULE_PATHNAME' LANGUAGE C;
 
 CREATE FUNCTION hnsw_sparsevec_check_value(internal) RETURNS internal
-	AS 'MODULE_PATHNAME' LANGUAGE C IMMUTABLE STRICT PARALLEL SAFE;
+	AS 'MODULE_PATHNAME' LANGUAGE C;
 
 CREATE OPERATOR CLASS vector_l1_ops
 	FOR TYPE vector USING hnsw AS
@@ -73,7 +85,8 @@ CREATE OPERATOR CLASS bit_hamming_ops
 	FUNCTION 1 hamming_distance(bit, bit),
 	FUNCTION 3 hamming_distance(bit, bit),
 	FUNCTION 6 ivfflat_bit_max_dims(internal),
-	FUNCTION 7 ivfflat_bit_support(internal);
+	FUNCTION 7 ivfflat_bit_update_center(internal, internal, internal),
+	FUNCTION 8 ivfflat_bit_sum_center(internal, internal);
 
 CREATE OPERATOR CLASS bit_hamming_ops
 	FOR TYPE bit USING hnsw AS
@@ -341,7 +354,8 @@ CREATE OPERATOR CLASS halfvec_l2_ops
 	FUNCTION 1 halfvec_l2_squared_distance(halfvec, halfvec),
 	FUNCTION 3 l2_distance(halfvec, halfvec),
 	FUNCTION 6 ivfflat_halfvec_max_dims(internal),
-	FUNCTION 7 ivfflat_halfvec_support(internal);
+	FUNCTION 7 ivfflat_halfvec_update_center(internal, internal, internal),
+	FUNCTION 8 ivfflat_halfvec_sum_center(internal, internal);
 
 CREATE OPERATOR CLASS halfvec_ip_ops
 	FOR TYPE halfvec USING ivfflat AS
@@ -351,7 +365,8 @@ CREATE OPERATOR CLASS halfvec_ip_ops
 	FUNCTION 4 l2_norm(halfvec),
 	FUNCTION 5 l2_normalize(halfvec),
 	FUNCTION 6 ivfflat_halfvec_max_dims(internal),
-	FUNCTION 7 ivfflat_halfvec_support(internal);
+	FUNCTION 7 ivfflat_halfvec_update_center(internal, internal, internal),
+	FUNCTION 8 ivfflat_halfvec_sum_center(internal, internal);
 
 CREATE OPERATOR CLASS halfvec_cosine_ops
 	FOR TYPE halfvec USING ivfflat AS
@@ -362,7 +377,8 @@ CREATE OPERATOR CLASS halfvec_cosine_ops
 	FUNCTION 4 l2_norm(halfvec),
 	FUNCTION 5 l2_normalize(halfvec),
 	FUNCTION 6 ivfflat_halfvec_max_dims(internal),
-	FUNCTION 7 ivfflat_halfvec_support(internal);
+	FUNCTION 7 ivfflat_halfvec_update_center(internal, internal, internal),
+	FUNCTION 8 ivfflat_halfvec_sum_center(internal, internal);
 
 CREATE OPERATOR CLASS halfvec_l2_ops
 	FOR TYPE halfvec USING hnsw AS
diff --git a/sql/vector.sql b/sql/vector.sql
index c168c84..9113b3f 100644
--- a/sql/vector.sql
+++ b/sql/vector.sql
@@ -264,28 +264,40 @@ COMMENT ON ACCESS METHOD hnsw IS 'hnsw index access method';
 -- access method private functions
 
 CREATE FUNCTION ivfflat_bit_max_dims(internal) RETURNS internal
-	AS 'MODULE_PATHNAME' LANGUAGE C IMMUTABLE STRICT PARALLEL SAFE;
+	AS 'MODULE_PATHNAME' LANGUAGE C;
 
 CREATE FUNCTION ivfflat_halfvec_max_dims(internal) RETURNS internal
-	AS 'MODULE_PATHNAME' LANGUAGE C IMMUTABLE STRICT PARALLEL SAFE;
+	AS 'MODULE_PATHNAME' LANGUAGE C;
 
-CREATE FUNCTION ivfflat_bit_support(internal) RETURNS internal
-	AS 'MODULE_PATHNAME' LANGUAGE C IMMUTABLE STRICT PARALLEL SAFE;
+CREATE FUNCTION ivfflat_vector_update_center(internal, internal, internal) RETURNS internal
+	AS 'MODULE_PATHNAME' LANGUAGE C;
 
-CREATE FUNCTION ivfflat_halfvec_support(internal) RETURNS internal
-	AS 'MODULE_PATHNAME' LANGUAGE C IMMUTABLE STRICT PARALLEL SAFE;
+CREATE FUNCTION ivfflat_bit_update_center(internal, internal, internal) RETURNS internal
+	AS 'MODULE_PATHNAME' LANGUAGE C;
+
+CREATE FUNCTION ivfflat_halfvec_update_center(internal, internal, internal) RETURNS internal
+	AS 'MODULE_PATHNAME' LANGUAGE C;
+
+CREATE FUNCTION ivfflat_vector_sum_center(internal, internal) RETURNS internal
+	AS 'MODULE_PATHNAME' LANGUAGE C;
+
+CREATE FUNCTION ivfflat_bit_sum_center(internal, internal) RETURNS internal
+	AS 'MODULE_PATHNAME' LANGUAGE C;
+
+CREATE FUNCTION ivfflat_halfvec_sum_center(internal, internal) RETURNS internal
+	AS 'MODULE_PATHNAME' LANGUAGE C;
 
 CREATE FUNCTION hnsw_bit_max_dims(internal) RETURNS internal
-	AS 'MODULE_PATHNAME' LANGUAGE C IMMUTABLE STRICT PARALLEL SAFE;
+	AS 'MODULE_PATHNAME' LANGUAGE C;
 
 CREATE FUNCTION hnsw_halfvec_max_dims(internal) RETURNS internal
-	AS 'MODULE_PATHNAME' LANGUAGE C IMMUTABLE STRICT PARALLEL SAFE;
+	AS 'MODULE_PATHNAME' LANGUAGE C;
 
 CREATE FUNCTION hnsw_sparsevec_max_dims(internal) RETURNS internal
-	AS 'MODULE_PATHNAME' LANGUAGE C IMMUTABLE STRICT PARALLEL SAFE;
+	AS 'MODULE_PATHNAME' LANGUAGE C;
 
 CREATE FUNCTION hnsw_sparsevec_check_value(internal) RETURNS internal
-	AS 'MODULE_PATHNAME' LANGUAGE C IMMUTABLE STRICT PARALLEL SAFE;
+	AS 'MODULE_PATHNAME' LANGUAGE C;
 
 -- vector opclasses
 
@@ -368,7 +380,8 @@ CREATE OPERATOR CLASS bit_hamming_ops
 	FUNCTION 1 hamming_distance(bit, bit),
 	FUNCTION 3 hamming_distance(bit, bit),
 	FUNCTION 6 ivfflat_bit_max_dims(internal),
-	FUNCTION 7 ivfflat_bit_support(internal);
+	FUNCTION 7 ivfflat_bit_update_center(internal, internal, internal),
+	FUNCTION 8 ivfflat_bit_sum_center(internal, internal);
 
 CREATE OPERATOR CLASS bit_hamming_ops
 	FOR TYPE bit USING hnsw AS
@@ -652,7 +665,8 @@ CREATE OPERATOR CLASS halfvec_l2_ops
 	FUNCTION 1 halfvec_l2_squared_distance(halfvec, halfvec),
 	FUNCTION 3 l2_distance(halfvec, halfvec),
 	FUNCTION 6 ivfflat_halfvec_max_dims(internal),
-	FUNCTION 7 ivfflat_halfvec_support(internal);
+	FUNCTION 7 ivfflat_halfvec_update_center(internal, internal, internal),
+	FUNCTION 8 ivfflat_halfvec_sum_center(internal, internal);
 
 CREATE OPERATOR CLASS halfvec_ip_ops
 	FOR TYPE halfvec USING ivfflat AS
@@ -662,7 +676,8 @@ CREATE OPERATOR CLASS halfvec_ip_ops
 	FUNCTION 4 l2_norm(halfvec),
 	FUNCTION 5 l2_normalize(halfvec),
 	FUNCTION 6 ivfflat_halfvec_max_dims(internal),
-	FUNCTION 7 ivfflat_halfvec_support(internal);
+	FUNCTION 7 ivfflat_halfvec_update_center(internal, internal, internal),
+	FUNCTION 8 ivfflat_halfvec_sum_center(internal, internal);
 
 CREATE OPERATOR CLASS halfvec_cosine_ops
 	FOR TYPE halfvec USING ivfflat AS
@@ -673,7 +688,8 @@ CREATE OPERATOR CLASS halfvec_cosine_ops
 	FUNCTION 4 l2_norm(halfvec),
 	FUNCTION 5 l2_normalize(halfvec),
 	FUNCTION 6 ivfflat_halfvec_max_dims(internal),
-	FUNCTION 7 ivfflat_halfvec_support(internal);
+	FUNCTION 7 ivfflat_halfvec_update_center(internal, internal, internal),
+	FUNCTION 8 ivfflat_halfvec_sum_center(internal, internal);
 
 CREATE OPERATOR CLASS halfvec_l2_ops
 	FOR TYPE halfvec USING hnsw AS
diff --git a/src/halfvec.c b/src/halfvec.c
index 32e82dd..30ec15f 100644
--- a/src/halfvec.c
+++ b/src/halfvec.c
@@ -978,7 +978,7 @@ halfvec_subvector(PG_FUNCTION_ARGS)
 /*
  * Internal helper to compare half vectors
  */
-int
+static int
 halfvec_cmp_internal(HalfVector * a, HalfVector * b)
 {
 	int			dim = Min(a->dim, b->dim);
diff --git a/src/halfvec.h b/src/halfvec.h
index cb021a6..d20a268 100644
--- a/src/halfvec.h
+++ b/src/halfvec.h
@@ -44,6 +44,5 @@ typedef struct HalfVector
 }			HalfVector;
 
 HalfVector *InitHalfVector(int dim);
-int			halfvec_cmp_internal(HalfVector * a, HalfVector * b);
 
 #endif
diff --git a/src/ivfbuild.c b/src/ivfbuild.c
index e95036c..0fad43e 100644
--- a/src/ivfbuild.c
+++ b/src/ivfbuild.c
@@ -319,27 +319,6 @@ InsertTuples(Relation index, IvfflatBuildState * buildstate, ForkNumber forkNum)
 	}
 }
 
-/*
- * Get type
- */
-static IvfflatType
-IvfflatGetType(Relation index)
-{
-	FmgrInfo   *procinfo = IvfflatOptionalProcInfo(index, IVFFLAT_TYPE_SUPPORT_PROC);
-	Oid			typid = TupleDescAttr(index->rd_att, 0)->atttypid;
-	IvfflatType result;
-
-	if (procinfo == NULL)
-		return IVFFLAT_TYPE_VECTOR;
-
-	result = (IvfflatType) DatumGetInt32(FunctionCall1(procinfo, ObjectIdGetDatum(typid)));
-
-	if (result == IVFFLAT_TYPE_UNSUPPORTED)
-		elog(ERROR, "type not supported for ivfflat index");
-
-	return result;
-}
-
 /*
  * Get max dimensions
  */
@@ -358,13 +337,14 @@ GetMaxDimensions(Relation index)
  * Get item size
  */
 static Size
-GetItemSize(IvfflatType type, int dimensions)
+GetItemSize(int maxDimensions, int dimensions)
 {
-	if (type == IVFFLAT_TYPE_VECTOR)
+	/* TODO Improve */
+	if (maxDimensions == IVFFLAT_MAX_DIM)
 		return VECTOR_SIZE(dimensions);
-	else if (type == IVFFLAT_TYPE_HALFVEC)
+	else if (maxDimensions == IVFFLAT_MAX_DIM * 2)
 		return HALFVEC_SIZE(dimensions);
-	else if (type == IVFFLAT_TYPE_BIT)
+	else if (maxDimensions == IVFFLAT_MAX_DIM * 32)
 		return VARBITTOTALLEN(dimensions);
 	else
 		elog(ERROR, "Unsupported type");
@@ -381,7 +361,6 @@ InitBuildState(IvfflatBuildState * buildstate, Relation heap, Relation index, In
 	buildstate->heap = heap;
 	buildstate->index = index;
 	buildstate->indexInfo = indexInfo;
-	buildstate->type = IvfflatGetType(index);
 	buildstate->lists = IvfflatGetLists(index);
 	buildstate->dimensions = TupleDescAttr(index->rd_att, 0)->atttypmod;
 
@@ -421,7 +400,7 @@ InitBuildState(IvfflatBuildState * buildstate, Relation heap, Relation index, In
 
 	buildstate->slot = MakeSingleTupleTableSlot(buildstate->tupdesc, &TTSOpsVirtual);
 
-	buildstate->centers = VectorArrayInit(buildstate->lists, buildstate->dimensions, GetItemSize(buildstate->type, buildstate->dimensions));
+	buildstate->centers = VectorArrayInit(buildstate->lists, buildstate->dimensions, GetItemSize(maxDimensions, buildstate->dimensions));
 	buildstate->listInfo = palloc(sizeof(ListInfo) * buildstate->lists);
 
 	buildstate->tmpCtx = AllocSetContextCreate(CurrentMemoryContext,
@@ -491,7 +470,7 @@ ComputeCenters(IvfflatBuildState * buildstate)
 	}
 
 	/* Calculate centers */
-	IvfflatBench("k-means", IvfflatKmeans(buildstate->index, buildstate->samples, buildstate->centers, buildstate->type));
+	IvfflatBench("k-means", IvfflatKmeans(buildstate->index, buildstate->samples, buildstate->centers));
 
 	/* Free samples before we allocate more memory */
 	VectorArrayFree(buildstate->samples);
diff --git a/src/ivfflat.c b/src/ivfflat.c
index 4ff77fa..9a226a6 100644
--- a/src/ivfflat.c
+++ b/src/ivfflat.c
@@ -188,7 +188,7 @@ ivfflathandler(PG_FUNCTION_ARGS)
 	IndexAmRoutine *amroutine = makeNode(IndexAmRoutine);
 
 	amroutine->amstrategies = 0;
-	amroutine->amsupport = 7;
+	amroutine->amsupport = 8;
 #if PG_VERSION_NUM >= 130000
 	amroutine->amoptsprocnum = 0;
 #endif
diff --git a/src/ivfflat.h b/src/ivfflat.h
index 21bcd3a..348d96d 100644
--- a/src/ivfflat.h
+++ b/src/ivfflat.h
@@ -30,7 +30,8 @@
 #define IVFFLAT_KMEANS_NORM_PROC 4
 #define IVFFLAT_NORMALIZE_PROC 5
 #define IVFFLAT_MAX_DIMS_PROC 6
-#define IVFFLAT_TYPE_SUPPORT_PROC 7
+#define IVFFLAT_UPDATE_CENTER_PROC 7
+#define IVFFLAT_SUM_CENTER_PROC 8
 
 #define IVFFLAT_VERSION	1
 #define IVFFLAT_MAGIC_NUMBER 0x14FF1A7
@@ -46,14 +47,6 @@
 #define IVFFLAT_MAX_LISTS		32768
 #define IVFFLAT_DEFAULT_PROBES	1
 
-typedef enum IvfflatType
-{
-	IVFFLAT_TYPE_VECTOR,
-	IVFFLAT_TYPE_HALFVEC,
-	IVFFLAT_TYPE_BIT,
-	IVFFLAT_TYPE_UNSUPPORTED
-}			IvfflatType;
-
 /* Build phases */
 /* PROGRESS_CREATEIDX_SUBPHASE_INITIALIZE is 1 */
 #define PROGRESS_IVFFLAT_PHASE_KMEANS	2
@@ -165,7 +158,6 @@ typedef struct IvfflatBuildState
 	Relation	heap;
 	Relation	index;
 	IndexInfo  *indexInfo;
-	IvfflatType type;
 
 	/* Settings */
 	int			dimensions;
@@ -279,7 +271,7 @@ typedef IvfflatScanOpaqueData * IvfflatScanOpaque;
 /* Methods */
 VectorArray VectorArrayInit(int maxlen, int dimensions, Size itemsize);
 void		VectorArrayFree(VectorArray arr);
-void		IvfflatKmeans(Relation index, VectorArray samples, VectorArray centers, IvfflatType type);
+void		IvfflatKmeans(Relation index, VectorArray samples, VectorArray centers);
 FmgrInfo   *IvfflatOptionalProcInfo(Relation index, uint16 procnum);
 Datum		IvfflatNormValue(FmgrInfo *procinfo, Oid collation, Datum value);
 bool		IvfflatCheckNorm(FmgrInfo *procinfo, Oid collation, Datum value);
diff --git a/src/ivfkmeans.c b/src/ivfkmeans.c
index b7246f0..6ac1ac8 100644
--- a/src/ivfkmeans.c
+++ b/src/ivfkmeans.c
@@ -13,14 +13,15 @@
 #include "utils/memutils.h"
 #include "vector.h"
 
+/* Support functions */
+PGDLLEXPORT Datum ivfflat_vector_update_center(PG_FUNCTION_ARGS);
+PGDLLEXPORT Datum ivfflat_vector_sum_center(PG_FUNCTION_ARGS);
+
 typedef struct KmeansState
 {
-	void		(*initCenter) (Pointer v, int dimensions);
-	void		(*updateCenter) (Pointer v, float *x);
-	void		(*sumCenter) (Pointer v, float *x);
-	int			(*comp) (const void *a, const void *b);
-	bool		separateAgg;
-	bool		checkDuplicates;
+	int			dimensions;
+	FmgrInfo   *updatecenterprocinfo;
+	FmgrInfo   *sumcenterprocinfo;
 }			KmeansState;
 
 /*
@@ -126,105 +127,20 @@ NormCenters(FmgrInfo *normalizeprocinfo, Oid collation, VectorArray centers)
 	MemoryContextDelete(normCtx);
 }
 
-/*
- * Compare vectors
- */
-static int
-CompareVectors(const void *a, const void *b)
+static void
+UpdateCenter(FmgrInfo *procinfo, Pointer center, int dimensions, float *x)
 {
-	return vector_cmp_internal((Vector *) a, (Vector *) b);
+	if (procinfo == NULL)
+		DirectFunctionCall3(ivfflat_vector_update_center, PointerGetDatum(center), Int32GetDatum(dimensions), PointerGetDatum(x));
+	else
+		FunctionCall3(procinfo, PointerGetDatum(center), Int32GetDatum(dimensions), PointerGetDatum(x));
 }
 
 /*
- * Compare half vectors
- */
-static int
-CompareHalfVectors(const void *a, const void *b)
-{
-	return halfvec_cmp_internal((HalfVector *) a, (HalfVector *) b);
-}
-
-/*
- * Compare bit vectors
- */
-static int
-CompareBitVectors(const void *a, const void *b)
-{
-	return DirectFunctionCall2(bitcmp, VarBitPGetDatum((VarBit *) a), VarBitPGetDatum((VarBit *) b));
-}
-
-/*
- * Sort vector array
+ * Quick approach if we have no data
  */
 static void
-SortVectorArray(VectorArray arr, KmeansState * kmeansstate)
-{
-	qsort(arr->items, arr->length, arr->itemsize, kmeansstate->comp);
-}
-
-static void
-VectorInitCenter(Pointer v, int dimensions)
-{
-	Vector	   *vec = (Vector *) v;
-
-	SET_VARSIZE(vec, VECTOR_SIZE(dimensions));
-	vec->dim = dimensions;
-}
-
-static void
-HalfvecInitCenter(Pointer v, int dimensions)
-{
-	HalfVector *vec = (HalfVector *) v;
-
-	SET_VARSIZE(vec, HALFVEC_SIZE(dimensions));
-	vec->dim = dimensions;
-}
-
-static void
-BitInitCenter(Pointer v, int dimensions)
-{
-	VarBit	   *vec = (VarBit *) v;
-
-	SET_VARSIZE(vec, VARBITTOTALLEN(dimensions));
-	VARBITLEN(vec) = dimensions;
-}
-
-static void
-VectorUpdateCenter(Pointer v, float *x)
-{
-	Vector	   *newCenter = (Vector *) v;
-
-	for (int k = 0; k < newCenter->dim; k++)
-		newCenter->x[k] = x[k];
-}
-
-static void
-HalfvecUpdateCenter(Pointer v, float *x)
-{
-	HalfVector *newCenter = (HalfVector *) v;
-
-	for (int k = 0; k < newCenter->dim; k++)
-		newCenter->x[k] = Float4ToHalfUnchecked(x[k]);
-}
-
-static void
-BitUpdateCenter(Pointer v, float *x)
-{
-	VarBit	   *newCenter = (VarBit *) v;
-	unsigned char *nx = VARBITS(newCenter);
-
-	for (uint32 k = 0; k < VARBITBYTES(newCenter); k++)
-		nx[k] = 0;
-
-	for (int k = 0; k < VARBITLEN(newCenter); k++)
-		nx[k / 8] |= (x[k] > 0.5 ? 1 : 0) << (7 - (k % 8));
-}
-
-/*
- * Quick approach if we have little data
- */
-static void
-QuickCenters(Relation index, VectorArray samples, VectorArray centers, KmeansState * kmeansstate)
+RandomCenters(Relation index, VectorArray centers, KmeansState * kmeansstate)
 {
 	int			dimensions = centers->dim;
 	Oid			collation = index->rd_indcollation[0];
@@ -232,24 +148,7 @@ QuickCenters(Relation index, VectorArray samples, VectorArray centers, KmeansSta
 	FmgrInfo   *normalizeprocinfo = IvfflatOptionalProcInfo(index, IVFFLAT_NORMALIZE_PROC);
 	float	   *x = (float *) palloc(sizeof(float) * dimensions);
 
-	/* Copy existing vectors while avoiding duplicates */
-	if (samples->length > 0)
-	{
-		SortVectorArray(samples, kmeansstate);
-
-		for (int i = 0; i < samples->length; i++)
-		{
-			Datum		vec = PointerGetDatum(VectorArrayGet(samples, i));
-
-			if (i == 0 || !datumIsEqual(vec, PointerGetDatum(VectorArrayGet(samples, i - 1)), false, -1))
-			{
-				VectorArraySet(centers, centers->length, DatumGetPointer(vec));
-				centers->length++;
-			}
-		}
-	}
-
-	/* Fill remaining with random data */
+	/* Fill with random data */
 	while (centers->length < centers->maxlen)
 	{
 		Pointer		center = VectorArrayGet(centers, centers->length);
@@ -257,13 +156,11 @@ QuickCenters(Relation index, VectorArray samples, VectorArray centers, KmeansSta
 		for (int i = 0; i < dimensions; i++)
 			x[i] = (float) RandomDouble();
 
-		kmeansstate->initCenter(center, dimensions);
-		kmeansstate->updateCenter(center, x);
+		UpdateCenter(kmeansstate->updatecenterprocinfo, center, dimensions, x);
 
 		centers->length++;
 	}
 
-	/* Fine if existing vectors are normalized twice */
 	if (normprocinfo != NULL)
 		NormCenters(normalizeprocinfo, collation, centers);
 
@@ -288,57 +185,39 @@ ShowMemoryUsage(MemoryContext context, Size estimatedSize)
 #endif
 
 static void
-VectorSumCenter(Pointer v, float *x)
+SumCenter(FmgrInfo *procinfo, Pointer sample, float *x)
 {
-	Vector	   *vec = (Vector *) v;
-
-	for (int k = 0; k < vec->dim; k++)
-		x[k] += vec->x[k];
-}
-
-static void
-HalfvecSumCenter(Pointer v, float *x)
-{
-	HalfVector *vec = (HalfVector *) v;
-
-	for (int k = 0; k < vec->dim; k++)
-		x[k] += HalfToFloat4(vec->x[k]);
-}
-
-static void
-BitSumCenter(Pointer v, float *x)
-{
-	VarBit	   *vec = (VarBit *) v;
-
-	for (int k = 0; k < VARBITLEN(v); k++)
-		x[k] += (float) (((VARBITS(vec)[k / 8]) >> (7 - (k % 8))) & 0x01);
+	if (procinfo == NULL)
+		DirectFunctionCall2(ivfflat_vector_sum_center, PointerGetDatum(sample), PointerGetDatum(x));
+	else
+		FunctionCall2(procinfo, PointerGetDatum(sample), PointerGetDatum(x));
 }
 
 /*
  * Sum centers
  */
 static void
-SumCenters(VectorArray samples, VectorArray aggCenters, int *closestCenters, KmeansState * kmeansstate)
+SumCenters(VectorArray samples, float *agg, int *closestCenters, KmeansState * kmeansstate)
 {
 	for (int j = 0; j < samples->length; j++)
 	{
-		Vector	   *aggCenter = (Vector *) VectorArrayGet(aggCenters, closestCenters[j]);
+		float	   *x = agg + ((int64) closestCenters[j] * kmeansstate->dimensions);
 
-		kmeansstate->sumCenter(VectorArrayGet(samples, j), aggCenter->x);
+		SumCenter(kmeansstate->sumcenterprocinfo, VectorArrayGet(samples, j), x);
 	}
 }
 
 /*
- * Set new centers
+ * Update centers
  */
 static void
-UpdateCenters(VectorArray aggCenters, VectorArray newCenters, KmeansState * kmeansstate)
+UpdateCenters(float *agg, VectorArray centers, KmeansState * kmeansstate)
 {
-	for (int j = 0; j < aggCenters->length; j++)
+	for (int j = 0; j < centers->length; j++)
 	{
-		Vector	   *aggCenter = (Vector *) VectorArrayGet(aggCenters, j);
+		float	   *x = agg + ((int64) j * kmeansstate->dimensions);
 
-		kmeansstate->updateCenter(VectorArrayGet(newCenters, j), aggCenter->x);
+		UpdateCenter(kmeansstate->updatecenterprocinfo, VectorArrayGet(centers, j), centers->dim, x);
 	}
 }
 
@@ -346,25 +225,25 @@ UpdateCenters(VectorArray aggCenters, VectorArray newCenters, KmeansState * kmea
  * Compute new centers
  */
 static void
-ComputeNewCenters(VectorArray samples, VectorArray aggCenters, VectorArray newCenters, int *centerCounts, int *closestCenters, FmgrInfo *normprocinfo, FmgrInfo *normalizeprocinfo, Oid collation, KmeansState * kmeansstate)
+ComputeNewCenters(VectorArray samples, float *agg, VectorArray newCenters, int *centerCounts, int *closestCenters, FmgrInfo *normprocinfo, FmgrInfo *normalizeprocinfo, Oid collation, KmeansState * kmeansstate)
 {
-	int			dimensions = aggCenters->dim;
-	int			numCenters = aggCenters->maxlen;
+	int			dimensions = kmeansstate->dimensions;
+	int			numCenters = newCenters->length;
 	int			numSamples = samples->length;
 
 	/* Reset sum and count */
 	for (int j = 0; j < numCenters; j++)
 	{
-		Vector	   *vec = (Vector *) VectorArrayGet(aggCenters, j);
+		float	   *x = agg + ((int64) j * dimensions);
 
 		for (int k = 0; k < dimensions; k++)
-			vec->x[k] = 0.0;
+			x[k] = 0.0;
 
 		centerCounts[j] = 0;
 	}
 
 	/* Increment sum of closest center */
-	SumCenters(samples, aggCenters, closestCenters, kmeansstate);
+	SumCenters(samples, agg, closestCenters, kmeansstate);
 
 	/* Increment count of closest center */
 	for (int j = 0; j < numSamples; j++)
@@ -373,7 +252,7 @@ ComputeNewCenters(VectorArray samples, VectorArray aggCenters, VectorArray newCe
 	/* Divide sum by count */
 	for (int j = 0; j < numCenters; j++)
 	{
-		Vector	   *vec = (Vector *) VectorArrayGet(aggCenters, j);
+		float	   *x = agg + ((int64) j * dimensions);
 
 		if (centerCounts[j] > 0)
 		{
@@ -381,24 +260,23 @@ ComputeNewCenters(VectorArray samples, VectorArray aggCenters, VectorArray newCe
 			/* TODO Update bounds */
 			for (int k = 0; k < dimensions; k++)
 			{
-				if (isinf(vec->x[k]))
-					vec->x[k] = vec->x[k] > 0 ? FLT_MAX : -FLT_MAX;
+				if (isinf(x[k]))
+					x[k] = x[k] > 0 ? FLT_MAX : -FLT_MAX;
 			}
 
 			for (int k = 0; k < dimensions; k++)
-				vec->x[k] /= centerCounts[j];
+				x[k] /= centerCounts[j];
 		}
 		else
 		{
 			/* TODO Handle empty centers properly */
 			for (int k = 0; k < dimensions; k++)
-				vec->x[k] = RandomDouble();
+				x[k] = RandomDouble();
 		}
 	}
 
-	/* Set new centers if different from agg centers */
-	if (kmeansstate->separateAgg)
-		UpdateCenters(aggCenters, newCenters, kmeansstate);
+	/* Set new centers */
+	UpdateCenters(agg, newCenters, kmeansstate);
 
 	/* Normalize if needed */
 	if (normprocinfo != NULL)
@@ -424,7 +302,7 @@ ElkanKmeans(Relation index, VectorArray samples, VectorArray centers, KmeansStat
 	int			numCenters = centers->maxlen;
 	int			numSamples = samples->length;
 	VectorArray newCenters;
-	VectorArray aggCenters;
+	float	   *agg;
 	int		   *centerCounts;
 	int		   *closestCenters;
 	float	   *lowerBound;
@@ -439,7 +317,7 @@ ElkanKmeans(Relation index, VectorArray samples, VectorArray centers, KmeansStat
 	Size		samplesSize = VECTOR_ARRAY_SIZE(samples->maxlen, samples->itemsize);
 	Size		centersSize = VECTOR_ARRAY_SIZE(centers->maxlen, centers->itemsize);
 	Size		newCentersSize = VECTOR_ARRAY_SIZE(numCenters, centers->itemsize);
-	Size		aggCentersSize = !kmeansstate->separateAgg ? 0 : VECTOR_ARRAY_SIZE(numCenters, VECTOR_SIZE(dimensions));
+	Size		aggSize = sizeof(float) * (int64) numCenters * dimensions;
 	Size		centerCountsSize = sizeof(int) * numCenters;
 	Size		closestCentersSize = sizeof(int) * numSamples;
 	Size		lowerBoundSize = sizeof(float) * numSamples * numCenters;
@@ -449,7 +327,7 @@ ElkanKmeans(Relation index, VectorArray samples, VectorArray centers, KmeansStat
 	Size		newcdistSize = sizeof(float) * numCenters;
 
 	/* Calculate total size */
-	Size		totalSize = samplesSize + centersSize + newCentersSize + aggCentersSize + centerCountsSize + closestCentersSize + lowerBoundSize + upperBoundSize + sSize + halfcdistSize + newcdistSize;
+	Size		totalSize = samplesSize + centersSize + newCentersSize + aggSize + centerCountsSize + closestCentersSize + lowerBoundSize + upperBoundSize + sSize + halfcdistSize + newcdistSize;
 
 	/* Check memory requirements */
 	/* Add one to error message to ceil */
@@ -477,6 +355,7 @@ ElkanKmeans(Relation index, VectorArray samples, VectorArray centers, KmeansStat
 
 	/* Allocate space */
 	/* Use float instead of double to save memory */
+	agg = palloc_extended(aggSize, MCXT_ALLOC_HUGE);
 	centerCounts = palloc(centerCountsSize);
 	closestCenters = palloc(closestCentersSize);
 	lowerBound = palloc_extended(lowerBoundSize, MCXT_ALLOC_HUGE);
@@ -489,24 +368,6 @@ ElkanKmeans(Relation index, VectorArray samples, VectorArray centers, KmeansStat
 	newCenters = VectorArrayInit(numCenters, dimensions, centers->itemsize);
 	newCenters->length = numCenters;
 
-	for (int j = 0; j < numCenters; j++)
-		kmeansstate->initCenter(VectorArrayGet(newCenters, j), dimensions);
-
-	/* Initialize agg centers */
-	if (!kmeansstate->separateAgg)
-	{
-		/* Use same centers to save memory */
-		aggCenters = newCenters;
-	}
-	else
-	{
-		aggCenters = VectorArrayInit(numCenters, dimensions, VECTOR_SIZE(dimensions));
-		aggCenters->length = numCenters;
-
-		for (int j = 0; j < numCenters; j++)
-			VectorInitCenter(VectorArrayGet(aggCenters, j), dimensions);
-	}
-
 #ifdef IVFFLAT_MEMORY
 	ShowMemoryUsage(oldCtx, totalSize);
 #endif
@@ -645,7 +506,7 @@ ElkanKmeans(Relation index, VectorArray samples, VectorArray centers, KmeansStat
 		}
 
 		/* Step 4: For each center c, let m(c) be mean of all points assigned */
-		ComputeNewCenters(samples, aggCenters, newCenters, centerCounts, closestCenters, normprocinfo, normalizeprocinfo, collation, kmeansstate);
+		ComputeNewCenters(samples, agg, newCenters, centerCounts, closestCenters, normprocinfo, normalizeprocinfo, collation, kmeansstate);
 
 		/* Step 5 */
 		for (int j = 0; j < numCenters; j++)
@@ -699,7 +560,7 @@ CheckCenters(Relation index, VectorArray centers, KmeansState * kmeansstate)
 		for (int j = 0; j < centers->dim; j++)
 			scratch[j] = 0;
 
-		kmeansstate->sumCenter(VectorArrayGet(centers, i), scratch);
+		SumCenter(kmeansstate->sumcenterprocinfo, VectorArrayGet(centers, i), scratch);
 
 		for (int j = 0; j < centers->dim; j++)
 		{
@@ -711,18 +572,6 @@ CheckCenters(Relation index, VectorArray centers, KmeansState * kmeansstate)
 		}
 	}
 
-	if (kmeansstate->checkDuplicates)
-	{
-		/* Ensure no duplicate centers */
-		SortVectorArray(centers, kmeansstate);
-
-		for (int i = 1; i < centers->length; i++)
-		{
-			if (datumIsEqual(PointerGetDatum(VectorArrayGet(centers, i)), PointerGetDatum(VectorArrayGet(centers, i - 1)), false, -1))
-				elog(ERROR, "Duplicate centers detected. Please report a bug.");
-		}
-	}
-
 	/* Ensure no zero vectors for cosine distance */
 	/* Check NORM_PROC instead of KMEANS_NORM_PROC */
 	normprocinfo = IvfflatOptionalProcInfo(index, IVFFLAT_NORM_PROC);
@@ -743,37 +592,11 @@ CheckCenters(Relation index, VectorArray centers, KmeansState * kmeansstate)
 }
 
 static void
-InitKmeansState(KmeansState * kmeansstate, IvfflatType type)
+InitKmeansState(KmeansState * kmeansstate, Relation index, int dimensions)
 {
-	if (type == IVFFLAT_TYPE_VECTOR)
-	{
-		kmeansstate->initCenter = VectorInitCenter;
-		kmeansstate->updateCenter = VectorUpdateCenter;
-		kmeansstate->sumCenter = VectorSumCenter;
-		kmeansstate->comp = CompareVectors;
-		kmeansstate->separateAgg = false;
-		kmeansstate->checkDuplicates = true;
-	}
-	else if (type == IVFFLAT_TYPE_HALFVEC)
-	{
-		kmeansstate->initCenter = HalfvecInitCenter;
-		kmeansstate->updateCenter = HalfvecUpdateCenter;
-		kmeansstate->sumCenter = HalfvecSumCenter;
-		kmeansstate->comp = CompareHalfVectors;
-		kmeansstate->separateAgg = true;
-		kmeansstate->checkDuplicates = true;
-	}
-	else if (type == IVFFLAT_TYPE_BIT)
-	{
-		kmeansstate->initCenter = BitInitCenter;
-		kmeansstate->updateCenter = BitUpdateCenter;
-		kmeansstate->sumCenter = BitSumCenter;
-		kmeansstate->comp = CompareBitVectors;
-		kmeansstate->separateAgg = true;
-		kmeansstate->checkDuplicates = false;
-	}
-	else
-		elog(ERROR, "Unsupported type");
+	kmeansstate->dimensions = dimensions;
+	kmeansstate->updatecenterprocinfo = IvfflatOptionalProcInfo(index, IVFFLAT_UPDATE_CENTER_PROC);
+	kmeansstate->sumcenterprocinfo = IvfflatOptionalProcInfo(index, IVFFLAT_SUM_CENTER_PROC);
 }
 
 /*
@@ -781,14 +604,14 @@ InitKmeansState(KmeansState * kmeansstate, IvfflatType type)
  * We use spherical k-means for inner product and cosine
  */
 void
-IvfflatKmeans(Relation index, VectorArray samples, VectorArray centers, IvfflatType type)
+IvfflatKmeans(Relation index, VectorArray samples, VectorArray centers)
 {
 	KmeansState kmeansstate;
 
-	InitKmeansState(&kmeansstate, type);
+	InitKmeansState(&kmeansstate, index, centers->dim);
 
-	if (samples->length <= centers->maxlen)
-		QuickCenters(index, samples, centers, &kmeansstate);
+	if (samples->length == 0)
+		RandomCenters(index, centers, &kmeansstate);
 	else
 		ElkanKmeans(index, samples, centers, &kmeansstate);
 
diff --git a/src/ivfutils.c b/src/ivfutils.c
index 9b5151d..95e2697 100644
--- a/src/ivfutils.c
+++ b/src/ivfutils.c
@@ -1,8 +1,11 @@
 #include "postgres.h"
 
 #include "access/generic_xlog.h"
+#include "bitvec.h"
 #include "catalog/pg_type.h"
 #include "fmgr.h"
+#include "halfutils.h"
+#include "halfvec.h"
 #include "ivfflat.h"
 #include "storage/bufmgr.h"
 
@@ -239,21 +242,123 @@ ivfflat_bit_max_dims(PG_FUNCTION_ARGS)
 	PG_RETURN_INT32(IVFFLAT_MAX_DIM * 32);
 };
 
-PGDLLEXPORT PG_FUNCTION_INFO_V1(ivfflat_halfvec_support);
+/*
+ * Fill a vector center from a float array
+ *
+ * The first argument is an output buffer: its header is set here and may
+ * not be valid on entry, so take the raw pointer instead of detoasting
+ */
+PGDLLEXPORT PG_FUNCTION_INFO_V1(ivfflat_vector_update_center);
 Datum
-ivfflat_halfvec_support(PG_FUNCTION_ARGS)
+ivfflat_vector_update_center(PG_FUNCTION_ARGS)
 {
-	PG_RETURN_INT32(IVFFLAT_TYPE_HALFVEC);
+	Vector	   *vec = (Vector *) PG_GETARG_POINTER(0);
+	int			dimensions = PG_GETARG_INT32(1);
+	float	   *x = (float *) PG_GETARG_POINTER(2);
+
+	SET_VARSIZE(vec, VECTOR_SIZE(dimensions));
+	vec->dim = dimensions;
+
+	for (int k = 0; k < dimensions; k++)
+		vec->x[k] = x[k];
+
+	PG_RETURN_VOID();
 };
 
-PGDLLEXPORT PG_FUNCTION_INFO_V1(ivfflat_bit_support);
+/*
+ * Fill a half vector center from a float array
+ *
+ * The first argument is an output buffer: its header is set here and may
+ * not be valid on entry, so take the raw pointer instead of detoasting
+ */
+PGDLLEXPORT PG_FUNCTION_INFO_V1(ivfflat_halfvec_update_center);
 Datum
-ivfflat_bit_support(PG_FUNCTION_ARGS)
+ivfflat_halfvec_update_center(PG_FUNCTION_ARGS)
 {
-	Oid			typid = PG_GETARG_OID(0);
+	HalfVector *vec = (HalfVector *) PG_GETARG_POINTER(0);
+	int			dimensions = PG_GETARG_INT32(1);
+	float	   *x = (float *) PG_GETARG_POINTER(2);
 
-	if (typid == BITOID)
-		PG_RETURN_INT32(IVFFLAT_TYPE_BIT);
-	else
-		PG_RETURN_INT32(IVFFLAT_TYPE_UNSUPPORTED);
+	SET_VARSIZE(vec, HALFVEC_SIZE(dimensions));
+	vec->dim = dimensions;
+
+	for (int k = 0; k < dimensions; k++)
+		vec->x[k] = Float4ToHalfUnchecked(x[k]);
+
+	PG_RETURN_VOID();
 };
+
+/*
+ * Fill a bit center from a float array (bit is set when the value > 0.5)
+ *
+ * The first argument is an output buffer: its header is set here and may
+ * not be valid on entry, so take the raw pointer instead of detoasting
+ */
+PGDLLEXPORT PG_FUNCTION_INFO_V1(ivfflat_bit_update_center);
+Datum
+ivfflat_bit_update_center(PG_FUNCTION_ARGS)
+{
+	VarBit	   *vec = (VarBit *) PG_GETARG_POINTER(0);
+	int			dimensions = PG_GETARG_INT32(1);
+	float	   *x = (float *) PG_GETARG_POINTER(2);
+	unsigned char *nx = VARBITS(vec);
+
+	SET_VARSIZE(vec, VARBITTOTALLEN(dimensions));
+	VARBITLEN(vec) = dimensions;
+
+	for (uint32 k = 0; k < VARBITBYTES(vec); k++)
+		nx[k] = 0;
+
+	for (int k = 0; k < dimensions; k++)
+		nx[k / 8] |= (x[k] > 0.5 ? 1 : 0) << (7 - (k % 8));
+
+	PG_RETURN_VOID();
+};
+
+/*
+ * Add each element of a vector to a float array
+ */
+PGDLLEXPORT PG_FUNCTION_INFO_V1(ivfflat_vector_sum_center);
+Datum
+ivfflat_vector_sum_center(PG_FUNCTION_ARGS)
+{
+	Vector	   *vec = PG_GETARG_VECTOR_P(0);
+	float	   *x = (float *) PG_GETARG_POINTER(1);
+
+	for (int k = 0; k < vec->dim; k++)
+		x[k] += vec->x[k];
+
+	PG_RETURN_VOID();
+};
+
+/*
+ * Add each element of a half vector to a float array
+ */
+PGDLLEXPORT PG_FUNCTION_INFO_V1(ivfflat_halfvec_sum_center);
+Datum
+ivfflat_halfvec_sum_center(PG_FUNCTION_ARGS)
+{
+	HalfVector *vec = PG_GETARG_HALFVEC_P(0);
+	float	   *x = (float *) PG_GETARG_POINTER(1);
+
+	for (int k = 0; k < vec->dim; k++)
+		x[k] += HalfToFloat4(vec->x[k]);
+
+	PG_RETURN_VOID();
+}
+
+/*
+ * Add each bit (as 0 or 1) of a bit string to a float array
+ */
+PGDLLEXPORT PG_FUNCTION_INFO_V1(ivfflat_bit_sum_center);
+Datum
+ivfflat_bit_sum_center(PG_FUNCTION_ARGS)
+{
+	VarBit	   *vec = PG_GETARG_VARBIT_P(0);
+	float	   *x = (float *) PG_GETARG_POINTER(1);
+
+	for (int k = 0; k < VARBITLEN(vec); k++)
+		x[k] += (float) (((VARBITS(vec)[k / 8]) >> (7 - (k % 8))) & 0x01);
+
+	PG_RETURN_VOID();
+}