Reduced support functions for ivfflat - #527

This commit is contained in:
Andrew Kane
2024-04-25 11:49:48 -07:00
parent c67dc6f9b0
commit e9c3c42e1c
7 changed files with 126 additions and 234 deletions

View File

@@ -22,28 +22,10 @@ CREATE OPERATOR || (
LEFTARG = vector, RIGHTARG = vector, PROCEDURE = vector_concat LEFTARG = vector, RIGHTARG = vector, PROCEDURE = vector_concat
); );
CREATE FUNCTION ivfflat_bit_max_dims(internal) RETURNS internal CREATE FUNCTION ivfflat_bit_support(internal) RETURNS internal
AS 'MODULE_PATHNAME' LANGUAGE C; AS 'MODULE_PATHNAME' LANGUAGE C;
CREATE FUNCTION ivfflat_halfvec_max_dims(internal) RETURNS internal CREATE FUNCTION ivfflat_halfvec_support(internal) RETURNS internal
AS 'MODULE_PATHNAME' LANGUAGE C;
CREATE FUNCTION ivfflat_vector_update_center(internal, internal, internal) RETURNS internal
AS 'MODULE_PATHNAME' LANGUAGE C;
CREATE FUNCTION ivfflat_bit_update_center(internal, internal, internal) RETURNS internal
AS 'MODULE_PATHNAME' LANGUAGE C;
CREATE FUNCTION ivfflat_halfvec_update_center(internal, internal, internal) RETURNS internal
AS 'MODULE_PATHNAME' LANGUAGE C;
CREATE FUNCTION ivfflat_vector_sum_center(internal, internal) RETURNS internal
AS 'MODULE_PATHNAME' LANGUAGE C;
CREATE FUNCTION ivfflat_bit_sum_center(internal, internal) RETURNS internal
AS 'MODULE_PATHNAME' LANGUAGE C;
CREATE FUNCTION ivfflat_halfvec_sum_center(internal, internal) RETURNS internal
AS 'MODULE_PATHNAME' LANGUAGE C; AS 'MODULE_PATHNAME' LANGUAGE C;
CREATE FUNCTION hnsw_bit_max_dims(internal) RETURNS internal CREATE FUNCTION hnsw_bit_max_dims(internal) RETURNS internal
@@ -84,9 +66,7 @@ CREATE OPERATOR CLASS bit_hamming_ops
OPERATOR 1 <~> (bit, bit) FOR ORDER BY float_ops, OPERATOR 1 <~> (bit, bit) FOR ORDER BY float_ops,
FUNCTION 1 hamming_distance(bit, bit), FUNCTION 1 hamming_distance(bit, bit),
FUNCTION 3 hamming_distance(bit, bit), FUNCTION 3 hamming_distance(bit, bit),
FUNCTION 6 ivfflat_bit_max_dims(internal), FUNCTION 6 ivfflat_bit_support(internal);
FUNCTION 7 ivfflat_bit_update_center(internal, internal, internal),
FUNCTION 8 ivfflat_bit_sum_center(internal, internal);
CREATE OPERATOR CLASS bit_hamming_ops CREATE OPERATOR CLASS bit_hamming_ops
FOR TYPE bit USING hnsw AS FOR TYPE bit USING hnsw AS
@@ -353,9 +333,7 @@ CREATE OPERATOR CLASS halfvec_l2_ops
OPERATOR 1 <-> (halfvec, halfvec) FOR ORDER BY float_ops, OPERATOR 1 <-> (halfvec, halfvec) FOR ORDER BY float_ops,
FUNCTION 1 halfvec_l2_squared_distance(halfvec, halfvec), FUNCTION 1 halfvec_l2_squared_distance(halfvec, halfvec),
FUNCTION 3 l2_distance(halfvec, halfvec), FUNCTION 3 l2_distance(halfvec, halfvec),
FUNCTION 6 ivfflat_halfvec_max_dims(internal), FUNCTION 6 ivfflat_halfvec_support(internal);
FUNCTION 7 ivfflat_halfvec_update_center(internal, internal, internal),
FUNCTION 8 ivfflat_halfvec_sum_center(internal, internal);
CREATE OPERATOR CLASS halfvec_ip_ops CREATE OPERATOR CLASS halfvec_ip_ops
FOR TYPE halfvec USING ivfflat AS FOR TYPE halfvec USING ivfflat AS
@@ -364,9 +342,7 @@ CREATE OPERATOR CLASS halfvec_ip_ops
FUNCTION 3 halfvec_spherical_distance(halfvec, halfvec), FUNCTION 3 halfvec_spherical_distance(halfvec, halfvec),
FUNCTION 4 l2_norm(halfvec), FUNCTION 4 l2_norm(halfvec),
FUNCTION 5 l2_normalize(halfvec), FUNCTION 5 l2_normalize(halfvec),
FUNCTION 6 ivfflat_halfvec_max_dims(internal), FUNCTION 6 ivfflat_halfvec_support(internal);
FUNCTION 7 ivfflat_halfvec_update_center(internal, internal, internal),
FUNCTION 8 ivfflat_halfvec_sum_center(internal, internal);
CREATE OPERATOR CLASS halfvec_cosine_ops CREATE OPERATOR CLASS halfvec_cosine_ops
FOR TYPE halfvec USING ivfflat AS FOR TYPE halfvec USING ivfflat AS
@@ -376,9 +352,7 @@ CREATE OPERATOR CLASS halfvec_cosine_ops
FUNCTION 3 halfvec_spherical_distance(halfvec, halfvec), FUNCTION 3 halfvec_spherical_distance(halfvec, halfvec),
FUNCTION 4 l2_norm(halfvec), FUNCTION 4 l2_norm(halfvec),
FUNCTION 5 l2_normalize(halfvec), FUNCTION 5 l2_normalize(halfvec),
FUNCTION 6 ivfflat_halfvec_max_dims(internal), FUNCTION 6 ivfflat_halfvec_support(internal);
FUNCTION 7 ivfflat_halfvec_update_center(internal, internal, internal),
FUNCTION 8 ivfflat_halfvec_sum_center(internal, internal);
CREATE OPERATOR CLASS halfvec_l2_ops CREATE OPERATOR CLASS halfvec_l2_ops
FOR TYPE halfvec USING hnsw AS FOR TYPE halfvec USING hnsw AS

View File

@@ -263,28 +263,10 @@ COMMENT ON ACCESS METHOD hnsw IS 'hnsw index access method';
-- access method private functions -- access method private functions
CREATE FUNCTION ivfflat_bit_max_dims(internal) RETURNS internal CREATE FUNCTION ivfflat_bit_support(internal) RETURNS internal
AS 'MODULE_PATHNAME' LANGUAGE C; AS 'MODULE_PATHNAME' LANGUAGE C;
CREATE FUNCTION ivfflat_halfvec_max_dims(internal) RETURNS internal CREATE FUNCTION ivfflat_halfvec_support(internal) RETURNS internal
AS 'MODULE_PATHNAME' LANGUAGE C;
CREATE FUNCTION ivfflat_vector_update_center(internal, internal, internal) RETURNS internal
AS 'MODULE_PATHNAME' LANGUAGE C;
CREATE FUNCTION ivfflat_bit_update_center(internal, internal, internal) RETURNS internal
AS 'MODULE_PATHNAME' LANGUAGE C;
CREATE FUNCTION ivfflat_halfvec_update_center(internal, internal, internal) RETURNS internal
AS 'MODULE_PATHNAME' LANGUAGE C;
CREATE FUNCTION ivfflat_vector_sum_center(internal, internal) RETURNS internal
AS 'MODULE_PATHNAME' LANGUAGE C;
CREATE FUNCTION ivfflat_bit_sum_center(internal, internal) RETURNS internal
AS 'MODULE_PATHNAME' LANGUAGE C;
CREATE FUNCTION ivfflat_halfvec_sum_center(internal, internal) RETURNS internal
AS 'MODULE_PATHNAME' LANGUAGE C; AS 'MODULE_PATHNAME' LANGUAGE C;
CREATE FUNCTION hnsw_bit_max_dims(internal) RETURNS internal CREATE FUNCTION hnsw_bit_max_dims(internal) RETURNS internal
@@ -379,9 +361,7 @@ CREATE OPERATOR CLASS bit_hamming_ops
OPERATOR 1 <~> (bit, bit) FOR ORDER BY float_ops, OPERATOR 1 <~> (bit, bit) FOR ORDER BY float_ops,
FUNCTION 1 hamming_distance(bit, bit), FUNCTION 1 hamming_distance(bit, bit),
FUNCTION 3 hamming_distance(bit, bit), FUNCTION 3 hamming_distance(bit, bit),
FUNCTION 6 ivfflat_bit_max_dims(internal), FUNCTION 6 ivfflat_bit_support(internal);
FUNCTION 7 ivfflat_bit_update_center(internal, internal, internal),
FUNCTION 8 ivfflat_bit_sum_center(internal, internal);
CREATE OPERATOR CLASS bit_hamming_ops CREATE OPERATOR CLASS bit_hamming_ops
FOR TYPE bit USING hnsw AS FOR TYPE bit USING hnsw AS
@@ -664,9 +644,7 @@ CREATE OPERATOR CLASS halfvec_l2_ops
OPERATOR 1 <-> (halfvec, halfvec) FOR ORDER BY float_ops, OPERATOR 1 <-> (halfvec, halfvec) FOR ORDER BY float_ops,
FUNCTION 1 halfvec_l2_squared_distance(halfvec, halfvec), FUNCTION 1 halfvec_l2_squared_distance(halfvec, halfvec),
FUNCTION 3 l2_distance(halfvec, halfvec), FUNCTION 3 l2_distance(halfvec, halfvec),
FUNCTION 6 ivfflat_halfvec_max_dims(internal), FUNCTION 6 ivfflat_halfvec_support(internal);
FUNCTION 7 ivfflat_halfvec_update_center(internal, internal, internal),
FUNCTION 8 ivfflat_halfvec_sum_center(internal, internal);
CREATE OPERATOR CLASS halfvec_ip_ops CREATE OPERATOR CLASS halfvec_ip_ops
FOR TYPE halfvec USING ivfflat AS FOR TYPE halfvec USING ivfflat AS
@@ -675,9 +653,7 @@ CREATE OPERATOR CLASS halfvec_ip_ops
FUNCTION 3 halfvec_spherical_distance(halfvec, halfvec), FUNCTION 3 halfvec_spherical_distance(halfvec, halfvec),
FUNCTION 4 l2_norm(halfvec), FUNCTION 4 l2_norm(halfvec),
FUNCTION 5 l2_normalize(halfvec), FUNCTION 5 l2_normalize(halfvec),
FUNCTION 6 ivfflat_halfvec_max_dims(internal), FUNCTION 6 ivfflat_halfvec_support(internal);
FUNCTION 7 ivfflat_halfvec_update_center(internal, internal, internal),
FUNCTION 8 ivfflat_halfvec_sum_center(internal, internal);
CREATE OPERATOR CLASS halfvec_cosine_ops CREATE OPERATOR CLASS halfvec_cosine_ops
FOR TYPE halfvec USING ivfflat AS FOR TYPE halfvec USING ivfflat AS
@@ -687,9 +663,7 @@ CREATE OPERATOR CLASS halfvec_cosine_ops
FUNCTION 3 halfvec_spherical_distance(halfvec, halfvec), FUNCTION 3 halfvec_spherical_distance(halfvec, halfvec),
FUNCTION 4 l2_norm(halfvec), FUNCTION 4 l2_norm(halfvec),
FUNCTION 5 l2_normalize(halfvec), FUNCTION 5 l2_normalize(halfvec),
FUNCTION 6 ivfflat_halfvec_max_dims(internal), FUNCTION 6 ivfflat_halfvec_support(internal);
FUNCTION 7 ivfflat_halfvec_update_center(internal, internal, internal),
FUNCTION 8 ivfflat_halfvec_sum_center(internal, internal);
CREATE OPERATOR CLASS halfvec_l2_ops CREATE OPERATOR CLASS halfvec_l2_ops
FOR TYPE halfvec USING hnsw AS FOR TYPE halfvec USING hnsw AS

View File

@@ -319,44 +319,13 @@ InsertTuples(Relation index, IvfflatBuildState * buildstate, ForkNumber forkNum)
} }
} }
/*
 * Return the maximum number of dimensions allowed for this index.
 *
 * If IvfflatOptionalProcInfo() yields a function for IVFFLAT_MAX_DIMS_PROC,
 * that function decides the limit; otherwise IVFFLAT_MAX_DIM is used.
 */
static int
GetMaxDimensions(Relation index)
{
	FmgrInfo   *maxDimsProc = IvfflatOptionalProcInfo(index, IVFFLAT_MAX_DIMS_PROC);
	Datum		limit;

	/* No support function available: fall back to the default limit */
	if (maxDimsProc == NULL)
		return IVFFLAT_MAX_DIM;

	/* The support function is called with a single NULL pointer argument */
	limit = FunctionCall1(maxDimsProc, PointerGetDatum(NULL));

	return DatumGetInt32(limit);
}
/*
 * Return the storage size of one item with the given number of dimensions.
 *
 * The dimension limit is used to tell the supported types apart:
 * IVFFLAT_MAX_DIM selects VECTOR_SIZE, twice that selects HALFVEC_SIZE, and
 * 32 times that selects VARBITTOTALLEN. Any other limit raises an error.
 */
static Size
GetItemSize(int maxDimensions, int dimensions)
{
	/* TODO Improve */
	if (maxDimensions == IVFFLAT_MAX_DIM)
		return VECTOR_SIZE(dimensions);

	if (maxDimensions == IVFFLAT_MAX_DIM * 2)
		return HALFVEC_SIZE(dimensions);

	if (maxDimensions == IVFFLAT_MAX_DIM * 32)
		return VARBITTOTALLEN(dimensions);

	elog(ERROR, "Unsupported type");
}
/* /*
* Initialize the build state * Initialize the build state
*/ */
static void static void
InitBuildState(IvfflatBuildState * buildstate, Relation heap, Relation index, IndexInfo *indexInfo) InitBuildState(IvfflatBuildState * buildstate, Relation heap, Relation index, IndexInfo *indexInfo)
{ {
int maxDimensions; IvfflatTypeInfo *typeInfo = &buildstate->typeInfo;
buildstate->heap = heap; buildstate->heap = heap;
buildstate->index = index; buildstate->index = index;
@@ -365,18 +334,19 @@ InitBuildState(IvfflatBuildState * buildstate, Relation heap, Relation index, In
buildstate->lists = IvfflatGetLists(index); buildstate->lists = IvfflatGetLists(index);
buildstate->dimensions = TupleDescAttr(index->rd_att, 0)->atttypmod; buildstate->dimensions = TupleDescAttr(index->rd_att, 0)->atttypmod;
typeInfo->dimensions = buildstate->dimensions;
GetTypeInfo(typeInfo, index);
/* Disallow varbit since require fixed dimensions */ /* Disallow varbit since require fixed dimensions */
if (TupleDescAttr(index->rd_att, 0)->atttypid == VARBITOID) if (TupleDescAttr(index->rd_att, 0)->atttypid == VARBITOID)
elog(ERROR, "type not supported for ivfflat index"); elog(ERROR, "type not supported for ivfflat index");
maxDimensions = GetMaxDimensions(index);
/* Require column to have dimensions to be indexed */ /* Require column to have dimensions to be indexed */
if (buildstate->dimensions < 0) if (buildstate->dimensions < 0)
elog(ERROR, "column does not have dimensions"); elog(ERROR, "column does not have dimensions");
if (buildstate->dimensions > maxDimensions) if (buildstate->dimensions > typeInfo->maxDimensions)
elog(ERROR, "column cannot have more than %d dimensions for ivfflat index", maxDimensions); elog(ERROR, "column cannot have more than %d dimensions for ivfflat index", typeInfo->maxDimensions);
buildstate->reltuples = 0; buildstate->reltuples = 0;
buildstate->indtuples = 0; buildstate->indtuples = 0;
@@ -400,7 +370,7 @@ InitBuildState(IvfflatBuildState * buildstate, Relation heap, Relation index, In
buildstate->slot = MakeSingleTupleTableSlot(buildstate->tupdesc, &TTSOpsVirtual); buildstate->slot = MakeSingleTupleTableSlot(buildstate->tupdesc, &TTSOpsVirtual);
buildstate->centers = VectorArrayInit(buildstate->lists, buildstate->dimensions, GetItemSize(maxDimensions, buildstate->dimensions)); buildstate->centers = VectorArrayInit(buildstate->lists, buildstate->dimensions, typeInfo->itemsize);
buildstate->listInfo = palloc(sizeof(ListInfo) * buildstate->lists); buildstate->listInfo = palloc(sizeof(ListInfo) * buildstate->lists);
buildstate->tmpCtx = AllocSetContextCreate(CurrentMemoryContext, buildstate->tmpCtx = AllocSetContextCreate(CurrentMemoryContext,
@@ -470,7 +440,7 @@ ComputeCenters(IvfflatBuildState * buildstate)
} }
/* Calculate centers */ /* Calculate centers */
IvfflatBench("k-means", IvfflatKmeans(buildstate->index, buildstate->samples, buildstate->centers)); IvfflatBench("k-means", IvfflatKmeans(buildstate->index, buildstate->samples, buildstate->centers, &buildstate->typeInfo));
/* Free samples before we allocate more memory */ /* Free samples before we allocate more memory */
VectorArrayFree(buildstate->samples); VectorArrayFree(buildstate->samples);

View File

@@ -188,7 +188,7 @@ ivfflathandler(PG_FUNCTION_ARGS)
IndexAmRoutine *amroutine = makeNode(IndexAmRoutine); IndexAmRoutine *amroutine = makeNode(IndexAmRoutine);
amroutine->amstrategies = 0; amroutine->amstrategies = 0;
amroutine->amsupport = 8; amroutine->amsupport = 6;
#if PG_VERSION_NUM >= 130000 #if PG_VERSION_NUM >= 130000
amroutine->amoptsprocnum = 0; amroutine->amoptsprocnum = 0;
#endif #endif

View File

@@ -29,9 +29,7 @@
#define IVFFLAT_KMEANS_DISTANCE_PROC 3 #define IVFFLAT_KMEANS_DISTANCE_PROC 3
#define IVFFLAT_KMEANS_NORM_PROC 4 #define IVFFLAT_KMEANS_NORM_PROC 4
#define IVFFLAT_NORMALIZE_PROC 5 #define IVFFLAT_NORMALIZE_PROC 5
#define IVFFLAT_MAX_DIMS_PROC 6 #define IVFFLAT_TYPE_INFO_PROC 6
#define IVFFLAT_UPDATE_CENTER_PROC 7
#define IVFFLAT_SUM_CENTER_PROC 8
#define IVFFLAT_VERSION 1 #define IVFFLAT_VERSION 1
#define IVFFLAT_MAGIC_NUMBER 0x14FF1A7 #define IVFFLAT_MAGIC_NUMBER 0x14FF1A7
@@ -152,12 +150,22 @@ typedef struct IvfflatLeader
char *ivfcenters; char *ivfcenters;
} IvfflatLeader; } IvfflatLeader;
/*
 * Per-type settings and callbacks for an ivfflat index.
 *
 * The caller sets `dimensions` before calling GetTypeInfo(); the remaining
 * fields are filled in either by GetTypeInfo() itself (vector defaults) or
 * by the operator class's IVFFLAT_TYPE_INFO_PROC support function.
 */
typedef struct IvfflatTypeInfo
{
/* Number of dimensions of the indexed column (input to GetTypeInfo) */
int dimensions;
/* Largest number of dimensions allowed for this type */
int maxDimensions;
/* Size of one item of this type with `dimensions` dimensions */
int itemsize;
/* Store the `dimensions` values of x into the center v */
void (*updateCenter) (Pointer v, int dimensions, float *x);
/* Add the elements of v into the float accumulator x */
void (*sumCenter) (Pointer v, float *x);
} IvfflatTypeInfo;
typedef struct IvfflatBuildState typedef struct IvfflatBuildState
{ {
/* Info */ /* Info */
Relation heap; Relation heap;
Relation index; Relation index;
IndexInfo *indexInfo; IndexInfo *indexInfo;
IvfflatTypeInfo typeInfo;
/* Settings */ /* Settings */
int dimensions; int dimensions;
@@ -271,7 +279,7 @@ typedef IvfflatScanOpaqueData * IvfflatScanOpaque;
/* Methods */ /* Methods */
VectorArray VectorArrayInit(int maxlen, int dimensions, Size itemsize); VectorArray VectorArrayInit(int maxlen, int dimensions, Size itemsize);
void VectorArrayFree(VectorArray arr); void VectorArrayFree(VectorArray arr);
void IvfflatKmeans(Relation index, VectorArray samples, VectorArray centers); void IvfflatKmeans(Relation index, VectorArray samples, VectorArray centers, IvfflatTypeInfo * typeInfo);
FmgrInfo *IvfflatOptionalProcInfo(Relation index, uint16 procnum); FmgrInfo *IvfflatOptionalProcInfo(Relation index, uint16 procnum);
Datum IvfflatNormValue(FmgrInfo *procinfo, Oid collation, Datum value); Datum IvfflatNormValue(FmgrInfo *procinfo, Oid collation, Datum value);
bool IvfflatCheckNorm(FmgrInfo *procinfo, Oid collation, Datum value); bool IvfflatCheckNorm(FmgrInfo *procinfo, Oid collation, Datum value);
@@ -284,6 +292,7 @@ Buffer IvfflatNewBuffer(Relation index, ForkNumber forkNum);
void IvfflatInitPage(Buffer buf, Page page); void IvfflatInitPage(Buffer buf, Page page);
void IvfflatInitRegisterPage(Relation index, Buffer *buf, Page *page, GenericXLogState **state); void IvfflatInitRegisterPage(Relation index, Buffer *buf, Page *page, GenericXLogState **state);
void IvfflatInit(void); void IvfflatInit(void);
void GetTypeInfo(IvfflatTypeInfo * typeInfo, Relation index);
PGDLLEXPORT void IvfflatParallelBuildMain(dsm_segment *seg, shm_toc *toc); PGDLLEXPORT void IvfflatParallelBuildMain(dsm_segment *seg, shm_toc *toc);
/* Index access methods */ /* Index access methods */

View File

@@ -13,17 +13,6 @@
#include "utils/memutils.h" #include "utils/memutils.h"
#include "vector.h" #include "vector.h"
/* Support functions */
PGDLLEXPORT Datum ivfflat_vector_update_center(PG_FUNCTION_ARGS);
PGDLLEXPORT Datum ivfflat_vector_sum_center(PG_FUNCTION_ARGS);
/*
 * State shared by the k-means routines; filled in by InitKmeansState().
 */
typedef struct KmeansState
{
/* Number of dimensions per center */
int dimensions;
/* Optional update-center support function; NULL selects the vector version */
FmgrInfo *updatecenterprocinfo;
/* Optional sum-center support function; NULL selects the vector version */
FmgrInfo *sumcenterprocinfo;
} KmeansState;
/* /*
* Initialize with kmeans++ * Initialize with kmeans++
* *
@@ -127,20 +116,11 @@ NormCenters(FmgrInfo *normalizeprocinfo, Oid collation, VectorArray centers)
MemoryContextDelete(normCtx); MemoryContextDelete(normCtx);
} }
/*
 * Store the values of x into a center.
 *
 * When procinfo is NULL the built-in vector implementation is called
 * directly; otherwise the supplied support function is invoked.
 */
static void
UpdateCenter(FmgrInfo *procinfo, Pointer center, int dimensions, float *x)
{
	Datum		centerArg = PointerGetDatum(center);
	Datum		dimensionsArg = Int32GetDatum(dimensions);
	Datum		valuesArg = PointerGetDatum(x);

	if (procinfo != NULL)
	{
		FunctionCall3(procinfo, centerArg, dimensionsArg, valuesArg);
		return;
	}

	DirectFunctionCall3(ivfflat_vector_update_center, centerArg, dimensionsArg, valuesArg);
}
/* /*
* Quick approach if we have no data * Quick approach if we have no data
*/ */
static void static void
RandomCenters(Relation index, VectorArray centers, KmeansState * kmeansstate) RandomCenters(Relation index, VectorArray centers, IvfflatTypeInfo * typeInfo)
{ {
int dimensions = centers->dim; int dimensions = centers->dim;
Oid collation = index->rd_indcollation[0]; Oid collation = index->rd_indcollation[0];
@@ -156,7 +136,7 @@ RandomCenters(Relation index, VectorArray centers, KmeansState * kmeansstate)
for (int i = 0; i < dimensions; i++) for (int i = 0; i < dimensions; i++)
x[i] = (float) RandomDouble(); x[i] = (float) RandomDouble();
UpdateCenter(kmeansstate->updatecenterprocinfo, center, dimensions, x); typeInfo->updateCenter(center, dimensions, x);
centers->length++; centers->length++;
} }
@@ -184,26 +164,17 @@ ShowMemoryUsage(MemoryContext context, Size estimatedSize)
} }
#endif #endif
/*
 * Add a sample to the float accumulator x.
 *
 * When procinfo is NULL the built-in vector implementation is called
 * directly; otherwise the supplied support function is invoked.
 */
static void
SumCenter(FmgrInfo *procinfo, Pointer sample, float *x)
{
	Datum		sampleArg = PointerGetDatum(sample);
	Datum		sumArg = PointerGetDatum(x);

	if (procinfo != NULL)
	{
		FunctionCall2(procinfo, sampleArg, sumArg);
		return;
	}

	DirectFunctionCall2(ivfflat_vector_sum_center, sampleArg, sumArg);
}
/* /*
* Sum centers * Sum centers
*/ */
static void static void
SumCenters(VectorArray samples, float *agg, int *closestCenters, KmeansState * kmeansstate) SumCenters(VectorArray samples, float *agg, int *closestCenters, IvfflatTypeInfo * typeInfo)
{ {
for (int j = 0; j < samples->length; j++) for (int j = 0; j < samples->length; j++)
{ {
float *x = agg + ((int64) closestCenters[j] * kmeansstate->dimensions); float *x = agg + ((int64) closestCenters[j] * samples->dim);
SumCenter(kmeansstate->sumcenterprocinfo, VectorArrayGet(samples, j), x); typeInfo->sumCenter(VectorArrayGet(samples, j), x);
} }
} }
@@ -211,13 +182,13 @@ SumCenters(VectorArray samples, float *agg, int *closestCenters, KmeansState * k
* Update centers * Update centers
*/ */
static void static void
UpdateCenters(float *agg, VectorArray centers, KmeansState * kmeansstate) UpdateCenters(float *agg, VectorArray centers, IvfflatTypeInfo * typeInfo)
{ {
for (int j = 0; j < centers->length; j++) for (int j = 0; j < centers->length; j++)
{ {
float *x = agg + ((int64) j * kmeansstate->dimensions); float *x = agg + ((int64) j * centers->dim);
UpdateCenter(kmeansstate->updatecenterprocinfo, VectorArrayGet(centers, j), centers->dim, x); typeInfo->updateCenter(VectorArrayGet(centers, j), centers->dim, x);
} }
} }
@@ -225,9 +196,9 @@ UpdateCenters(float *agg, VectorArray centers, KmeansState * kmeansstate)
* Compute new centers * Compute new centers
*/ */
static void static void
ComputeNewCenters(VectorArray samples, float *agg, VectorArray newCenters, int *centerCounts, int *closestCenters, FmgrInfo *normprocinfo, FmgrInfo *normalizeprocinfo, Oid collation, KmeansState * kmeansstate) ComputeNewCenters(VectorArray samples, float *agg, VectorArray newCenters, int *centerCounts, int *closestCenters, FmgrInfo *normprocinfo, FmgrInfo *normalizeprocinfo, Oid collation, IvfflatTypeInfo * typeInfo)
{ {
int dimensions = kmeansstate->dimensions; int dimensions = newCenters->dim;
int numCenters = newCenters->length; int numCenters = newCenters->length;
int numSamples = samples->length; int numSamples = samples->length;
@@ -243,7 +214,7 @@ ComputeNewCenters(VectorArray samples, float *agg, VectorArray newCenters, int *
} }
/* Increment sum of closest center */ /* Increment sum of closest center */
SumCenters(samples, agg, closestCenters, kmeansstate); SumCenters(samples, agg, closestCenters, typeInfo);
/* Increment count of closest center */ /* Increment count of closest center */
for (int j = 0; j < numSamples; j++) for (int j = 0; j < numSamples; j++)
@@ -276,7 +247,7 @@ ComputeNewCenters(VectorArray samples, float *agg, VectorArray newCenters, int *
} }
/* Set new centers */ /* Set new centers */
UpdateCenters(agg, newCenters, kmeansstate); UpdateCenters(agg, newCenters, typeInfo);
/* Normalize if needed */ /* Normalize if needed */
if (normprocinfo != NULL) if (normprocinfo != NULL)
@@ -292,7 +263,7 @@ ComputeNewCenters(VectorArray samples, float *agg, VectorArray newCenters, int *
* https://www.aaai.org/Papers/ICML/2003/ICML03-022.pdf * https://www.aaai.org/Papers/ICML/2003/ICML03-022.pdf
*/ */
static void static void
ElkanKmeans(Relation index, VectorArray samples, VectorArray centers, KmeansState * kmeansstate) ElkanKmeans(Relation index, VectorArray samples, VectorArray centers, IvfflatTypeInfo * typeInfo)
{ {
FmgrInfo *procinfo; FmgrInfo *procinfo;
FmgrInfo *normprocinfo; FmgrInfo *normprocinfo;
@@ -506,7 +477,7 @@ ElkanKmeans(Relation index, VectorArray samples, VectorArray centers, KmeansStat
} }
/* Step 4: For each center c, let m(c) be mean of all points assigned */ /* Step 4: For each center c, let m(c) be mean of all points assigned */
ComputeNewCenters(samples, agg, newCenters, centerCounts, closestCenters, normprocinfo, normalizeprocinfo, collation, kmeansstate); ComputeNewCenters(samples, agg, newCenters, centerCounts, closestCenters, normprocinfo, normalizeprocinfo, collation, typeInfo);
/* Step 5 */ /* Step 5 */
for (int j = 0; j < numCenters; j++) for (int j = 0; j < numCenters; j++)
@@ -546,7 +517,7 @@ ElkanKmeans(Relation index, VectorArray samples, VectorArray centers, KmeansStat
* Detect issues with centers * Detect issues with centers
*/ */
static void static void
CheckCenters(Relation index, VectorArray centers, KmeansState * kmeansstate) CheckCenters(Relation index, VectorArray centers, IvfflatTypeInfo * typeInfo)
{ {
FmgrInfo *normprocinfo; FmgrInfo *normprocinfo;
float *scratch = palloc(sizeof(float) * centers->dim); float *scratch = palloc(sizeof(float) * centers->dim);
@@ -560,7 +531,7 @@ CheckCenters(Relation index, VectorArray centers, KmeansState * kmeansstate)
for (int j = 0; j < centers->dim; j++) for (int j = 0; j < centers->dim; j++)
scratch[j] = 0; scratch[j] = 0;
SumCenter(kmeansstate->sumcenterprocinfo, VectorArrayGet(centers, i), scratch); typeInfo->sumCenter(VectorArrayGet(centers, i), scratch);
for (int j = 0; j < centers->dim; j++) for (int j = 0; j < centers->dim; j++)
{ {
@@ -591,29 +562,17 @@ CheckCenters(Relation index, VectorArray centers, KmeansState * kmeansstate)
pfree(scratch); pfree(scratch);
} }
/*
 * Initialize the k-means state for an index.
 *
 * Records the number of dimensions and looks up the optional update-center
 * and sum-center support functions (either may come back NULL).
 */
static void
InitKmeansState(KmeansState * kmeansstate, Relation index, int dimensions)
{
	FmgrInfo   *updateProc;
	FmgrInfo   *sumProc;

	kmeansstate->dimensions = dimensions;

	updateProc = IvfflatOptionalProcInfo(index, IVFFLAT_UPDATE_CENTER_PROC);
	kmeansstate->updatecenterprocinfo = updateProc;

	sumProc = IvfflatOptionalProcInfo(index, IVFFLAT_SUM_CENTER_PROC);
	kmeansstate->sumcenterprocinfo = sumProc;
}
/* /*
* Perform naive k-means centering * Perform naive k-means centering
* We use spherical k-means for inner product and cosine * We use spherical k-means for inner product and cosine
*/ */
void void
IvfflatKmeans(Relation index, VectorArray samples, VectorArray centers) IvfflatKmeans(Relation index, VectorArray samples, VectorArray centers, IvfflatTypeInfo * typeInfo)
{ {
KmeansState kmeansstate;
InitKmeansState(&kmeansstate, index, centers->dim);
if (samples->length == 0) if (samples->length == 0)
RandomCenters(index, centers, &kmeansstate); RandomCenters(index, centers, typeInfo);
else else
ElkanKmeans(index, samples, centers, &kmeansstate); ElkanKmeans(index, samples, centers, typeInfo);
CheckCenters(index, centers, &kmeansstate); CheckCenters(index, centers, typeInfo);
} }

View File

@@ -228,61 +228,34 @@ IvfflatUpdateList(Relation index, ListInfo listInfo,
} }
} }
PGDLLEXPORT PG_FUNCTION_INFO_V1(ivfflat_halfvec_max_dims); static void
Datum VectorUpdateCenter(Pointer v, int dimensions, float *x)
ivfflat_halfvec_max_dims(PG_FUNCTION_ARGS)
{ {
PG_RETURN_INT32(IVFFLAT_MAX_DIM * 2); Vector *vec = (Vector *) v;
};
PGDLLEXPORT PG_FUNCTION_INFO_V1(ivfflat_bit_max_dims);
Datum
ivfflat_bit_max_dims(PG_FUNCTION_ARGS)
{
PG_RETURN_INT32(IVFFLAT_MAX_DIM * 32);
};
PGDLLEXPORT PG_FUNCTION_INFO_V1(ivfflat_vector_update_center);
Datum
ivfflat_vector_update_center(PG_FUNCTION_ARGS)
{
Vector *vec = PG_GETARG_VECTOR_P(0);
int dimensions = PG_GETARG_INT32(1);
float *x = (float *) PG_GETARG_POINTER(2);
SET_VARSIZE(vec, VECTOR_SIZE(dimensions)); SET_VARSIZE(vec, VECTOR_SIZE(dimensions));
vec->dim = dimensions; vec->dim = dimensions;
for (int k = 0; k < dimensions; k++) for (int k = 0; k < dimensions; k++)
vec->x[k] = x[k]; vec->x[k] = x[k];
}
PG_RETURN_VOID(); static void
}; HalfvecUpdateCenter(Pointer v, int dimensions, float *x)
PGDLLEXPORT PG_FUNCTION_INFO_V1(ivfflat_halfvec_update_center);
Datum
ivfflat_halfvec_update_center(PG_FUNCTION_ARGS)
{ {
HalfVector *vec = PG_GETARG_HALFVEC_P(0); HalfVector *vec = (HalfVector *) v;
int dimensions = PG_GETARG_INT32(1);
float *x = (float *) PG_GETARG_POINTER(2);
SET_VARSIZE(vec, HALFVEC_SIZE(dimensions)); SET_VARSIZE(vec, HALFVEC_SIZE(dimensions));
vec->dim = dimensions; vec->dim = dimensions;
for (int k = 0; k < dimensions; k++) for (int k = 0; k < dimensions; k++)
vec->x[k] = Float4ToHalfUnchecked(x[k]); vec->x[k] = Float4ToHalfUnchecked(x[k]);
}
PG_RETURN_VOID(); static void
}; BitUpdateCenter(Pointer v, int dimensions, float *x)
PGDLLEXPORT PG_FUNCTION_INFO_V1(ivfflat_bit_update_center);
Datum
ivfflat_bit_update_center(PG_FUNCTION_ARGS)
{ {
VarBit *vec = PG_GETARG_VARBIT_P(0); VarBit *vec = (VarBit *) v;
int dimensions = PG_GETARG_INT32(1);
float *x = (float *) PG_GETARG_POINTER(2);
unsigned char *nx = VARBITS(vec); unsigned char *nx = VARBITS(vec);
SET_VARSIZE(vec, VARBITTOTALLEN(dimensions)); SET_VARSIZE(vec, VARBITTOTALLEN(dimensions));
@@ -293,45 +266,78 @@ ivfflat_bit_update_center(PG_FUNCTION_ARGS)
for (int k = 0; k < dimensions; k++) for (int k = 0; k < dimensions; k++)
nx[k / 8] |= (x[k] > 0.5 ? 1 : 0) << (7 - (k % 8)); nx[k / 8] |= (x[k] > 0.5 ? 1 : 0) << (7 - (k % 8));
}
PG_RETURN_VOID(); static void
}; VectorSumCenter(Pointer v, float *x)
PGDLLEXPORT PG_FUNCTION_INFO_V1(ivfflat_vector_sum_center);
Datum
ivfflat_vector_sum_center(PG_FUNCTION_ARGS)
{ {
Vector *vec = PG_GETARG_VECTOR_P(0); Vector *vec = (Vector *) v;
float *x = (float *) PG_GETARG_POINTER(1);
for (int k = 0; k < vec->dim; k++) for (int k = 0; k < vec->dim; k++)
x[k] += vec->x[k]; x[k] += vec->x[k];
}
/*
 * Add each element of a half vector, converted to float, to the matching
 * slot of the accumulator x.
 */
static void
HalfvecSumCenter(Pointer v, float *x)
{
	HalfVector *halfvec = (HalfVector *) v;
	int			k = 0;

	while (k < halfvec->dim)
	{
		x[k] += HalfToFloat4(halfvec->x[k]);
		k++;
	}
}
/*
 * Add each bit of a bit string (as 0 or 1) to the matching slot of the
 * accumulator x. Bits are read most-significant first within each byte.
 */
static void
BitSumCenter(Pointer v, float *x)
{
	VarBit	   *bits = (VarBit *) v;

	for (int k = 0; k < VARBITLEN(bits); k++)
	{
		unsigned char byte = VARBITS(bits)[k / 8];
		int			shift = 7 - (k % 8);

		x[k] += (float) ((byte >> shift) & 0x01);
	}
}
/*
 * Fill in the type info for an index.
 *
 * typeInfo->dimensions must already be set, since the item size is derived
 * from it. If IvfflatOptionalProcInfo() yields a function for
 * IVFFLAT_TYPE_INFO_PROC, it is called with typeInfo as its only argument
 * and is responsible for populating it; otherwise the vector defaults are
 * used.
 */
void
GetTypeInfo(IvfflatTypeInfo * typeInfo, Relation index)
{
	FmgrInfo   *typeInfoProc = IvfflatOptionalProcInfo(index, IVFFLAT_TYPE_INFO_PROC);

	/* Operator class supplies its own type info */
	if (typeInfoProc != NULL)
	{
		FunctionCall1(typeInfoProc, PointerGetDatum(typeInfo));
		return;
	}

	/* Defaults for the vector type */
	typeInfo->maxDimensions = IVFFLAT_MAX_DIM;
	typeInfo->itemsize = VECTOR_SIZE(typeInfo->dimensions);
	typeInfo->updateCenter = VectorUpdateCenter;
	typeInfo->sumCenter = VectorSumCenter;
}
PGDLLEXPORT PG_FUNCTION_INFO_V1(ivfflat_halfvec_support);
Datum
ivfflat_halfvec_support(PG_FUNCTION_ARGS)
{
IvfflatTypeInfo *typeInfo = (IvfflatTypeInfo *) PG_GETARG_POINTER(0);
typeInfo->maxDimensions = IVFFLAT_MAX_DIM * 2;
typeInfo->itemsize = HALFVEC_SIZE(typeInfo->dimensions);
typeInfo->updateCenter = HalfvecUpdateCenter;
typeInfo->sumCenter = HalfvecSumCenter;
PG_RETURN_VOID(); PG_RETURN_VOID();
}; };
PGDLLEXPORT PG_FUNCTION_INFO_V1(ivfflat_halfvec_sum_center); PGDLLEXPORT PG_FUNCTION_INFO_V1(ivfflat_bit_support);
Datum Datum
ivfflat_halfvec_sum_center(PG_FUNCTION_ARGS) ivfflat_bit_support(PG_FUNCTION_ARGS)
{ {
HalfVector *vec = PG_GETARG_HALFVEC_P(0); IvfflatTypeInfo *typeInfo = (IvfflatTypeInfo *) PG_GETARG_POINTER(0);
float *x = (float *) PG_GETARG_POINTER(1);
for (int k = 0; k < vec->dim; k++) typeInfo->maxDimensions = IVFFLAT_MAX_DIM * 32;
x[k] += HalfToFloat4(vec->x[k]); typeInfo->itemsize = VARBITTOTALLEN(typeInfo->dimensions);
typeInfo->updateCenter = BitUpdateCenter;
typeInfo->sumCenter = BitSumCenter;
PG_RETURN_VOID(); PG_RETURN_VOID();
} };
/*
 * Support function: add each bit of a bit string (argument 0) to the float
 * accumulator passed as a pointer in argument 1. Bits are read
 * most-significant first within each byte.
 */
PGDLLEXPORT PG_FUNCTION_INFO_V1(ivfflat_bit_sum_center);
Datum
ivfflat_bit_sum_center(PG_FUNCTION_ARGS)
{
	VarBit	   *bits = PG_GETARG_VARBIT_P(0);
	float	   *sums = (float *) PG_GETARG_POINTER(1);

	for (int k = 0; k < VARBITLEN(bits); k++)
	{
		unsigned char byte = VARBITS(bits)[k / 8];
		int			shift = 7 - (k % 8);

		sums[k] += (float) ((byte >> shift) & 0x01);
	}

	PG_RETURN_VOID();
}