mirror of
https://github.com/pgvector/pgvector.git
synced 2026-06-06 05:51:21 +08:00
Removed type-specific code from IVFFlat - #527
This commit is contained in:
@@ -23,28 +23,40 @@ CREATE OPERATOR || (
|
||||
);
|
||||
|
||||
CREATE FUNCTION ivfflat_bit_max_dims(internal) RETURNS internal
|
||||
AS 'MODULE_PATHNAME' LANGUAGE C IMMUTABLE STRICT PARALLEL SAFE;
|
||||
AS 'MODULE_PATHNAME' LANGUAGE C;
|
||||
|
||||
CREATE FUNCTION ivfflat_halfvec_max_dims(internal) RETURNS internal
|
||||
AS 'MODULE_PATHNAME' LANGUAGE C IMMUTABLE STRICT PARALLEL SAFE;
|
||||
AS 'MODULE_PATHNAME' LANGUAGE C;
|
||||
|
||||
CREATE FUNCTION ivfflat_bit_support(internal) RETURNS internal
|
||||
AS 'MODULE_PATHNAME' LANGUAGE C IMMUTABLE STRICT PARALLEL SAFE;
|
||||
CREATE FUNCTION ivfflat_vector_update_center(internal, internal, internal) RETURNS internal
|
||||
AS 'MODULE_PATHNAME' LANGUAGE C;
|
||||
|
||||
CREATE FUNCTION ivfflat_halfvec_support(internal) RETURNS internal
|
||||
AS 'MODULE_PATHNAME' LANGUAGE C IMMUTABLE STRICT PARALLEL SAFE;
|
||||
CREATE FUNCTION ivfflat_bit_update_center(internal, internal, internal) RETURNS internal
|
||||
AS 'MODULE_PATHNAME' LANGUAGE C;
|
||||
|
||||
CREATE FUNCTION ivfflat_halfvec_update_center(internal, internal, internal) RETURNS internal
|
||||
AS 'MODULE_PATHNAME' LANGUAGE C;
|
||||
|
||||
CREATE FUNCTION ivfflat_vector_sum_center(internal, internal) RETURNS internal
|
||||
AS 'MODULE_PATHNAME' LANGUAGE C;
|
||||
|
||||
CREATE FUNCTION ivfflat_bit_sum_center(internal, internal) RETURNS internal
|
||||
AS 'MODULE_PATHNAME' LANGUAGE C;
|
||||
|
||||
CREATE FUNCTION ivfflat_halfvec_sum_center(internal, internal) RETURNS internal
|
||||
AS 'MODULE_PATHNAME' LANGUAGE C;
|
||||
|
||||
CREATE FUNCTION hnsw_bit_max_dims(internal) RETURNS internal
|
||||
AS 'MODULE_PATHNAME' LANGUAGE C IMMUTABLE STRICT PARALLEL SAFE;
|
||||
AS 'MODULE_PATHNAME' LANGUAGE C;
|
||||
|
||||
CREATE FUNCTION hnsw_halfvec_max_dims(internal) RETURNS internal
|
||||
AS 'MODULE_PATHNAME' LANGUAGE C IMMUTABLE STRICT PARALLEL SAFE;
|
||||
AS 'MODULE_PATHNAME' LANGUAGE C;
|
||||
|
||||
CREATE FUNCTION hnsw_sparsevec_max_dims(internal) RETURNS internal
|
||||
AS 'MODULE_PATHNAME' LANGUAGE C IMMUTABLE STRICT PARALLEL SAFE;
|
||||
AS 'MODULE_PATHNAME' LANGUAGE C;
|
||||
|
||||
CREATE FUNCTION hnsw_sparsevec_check_value(internal) RETURNS internal
|
||||
AS 'MODULE_PATHNAME' LANGUAGE C IMMUTABLE STRICT PARALLEL SAFE;
|
||||
AS 'MODULE_PATHNAME' LANGUAGE C;
|
||||
|
||||
CREATE OPERATOR CLASS vector_l1_ops
|
||||
FOR TYPE vector USING hnsw AS
|
||||
@@ -73,7 +85,8 @@ CREATE OPERATOR CLASS bit_hamming_ops
|
||||
FUNCTION 1 hamming_distance(bit, bit),
|
||||
FUNCTION 3 hamming_distance(bit, bit),
|
||||
FUNCTION 6 ivfflat_bit_max_dims(internal),
|
||||
FUNCTION 7 ivfflat_bit_support(internal);
|
||||
FUNCTION 7 ivfflat_bit_update_center(internal, internal, internal),
|
||||
FUNCTION 8 ivfflat_bit_sum_center(internal, internal);
|
||||
|
||||
CREATE OPERATOR CLASS bit_hamming_ops
|
||||
FOR TYPE bit USING hnsw AS
|
||||
@@ -341,7 +354,8 @@ CREATE OPERATOR CLASS halfvec_l2_ops
|
||||
FUNCTION 1 halfvec_l2_squared_distance(halfvec, halfvec),
|
||||
FUNCTION 3 l2_distance(halfvec, halfvec),
|
||||
FUNCTION 6 ivfflat_halfvec_max_dims(internal),
|
||||
FUNCTION 7 ivfflat_halfvec_support(internal);
|
||||
FUNCTION 7 ivfflat_halfvec_update_center(internal, internal, internal),
|
||||
FUNCTION 8 ivfflat_halfvec_sum_center(internal, internal);
|
||||
|
||||
CREATE OPERATOR CLASS halfvec_ip_ops
|
||||
FOR TYPE halfvec USING ivfflat AS
|
||||
@@ -351,7 +365,8 @@ CREATE OPERATOR CLASS halfvec_ip_ops
|
||||
FUNCTION 4 l2_norm(halfvec),
|
||||
FUNCTION 5 l2_normalize(halfvec),
|
||||
FUNCTION 6 ivfflat_halfvec_max_dims(internal),
|
||||
FUNCTION 7 ivfflat_halfvec_support(internal);
|
||||
FUNCTION 7 ivfflat_halfvec_update_center(internal, internal, internal),
|
||||
FUNCTION 8 ivfflat_halfvec_sum_center(internal, internal);
|
||||
|
||||
CREATE OPERATOR CLASS halfvec_cosine_ops
|
||||
FOR TYPE halfvec USING ivfflat AS
|
||||
@@ -362,7 +377,8 @@ CREATE OPERATOR CLASS halfvec_cosine_ops
|
||||
FUNCTION 4 l2_norm(halfvec),
|
||||
FUNCTION 5 l2_normalize(halfvec),
|
||||
FUNCTION 6 ivfflat_halfvec_max_dims(internal),
|
||||
FUNCTION 7 ivfflat_halfvec_support(internal);
|
||||
FUNCTION 7 ivfflat_halfvec_update_center(internal, internal, internal),
|
||||
FUNCTION 8 ivfflat_halfvec_sum_center(internal, internal);
|
||||
|
||||
CREATE OPERATOR CLASS halfvec_l2_ops
|
||||
FOR TYPE halfvec USING hnsw AS
|
||||
|
||||
@@ -264,28 +264,40 @@ COMMENT ON ACCESS METHOD hnsw IS 'hnsw index access method';
|
||||
-- access method private functions
|
||||
|
||||
CREATE FUNCTION ivfflat_bit_max_dims(internal) RETURNS internal
|
||||
AS 'MODULE_PATHNAME' LANGUAGE C IMMUTABLE STRICT PARALLEL SAFE;
|
||||
AS 'MODULE_PATHNAME' LANGUAGE C;
|
||||
|
||||
CREATE FUNCTION ivfflat_halfvec_max_dims(internal) RETURNS internal
|
||||
AS 'MODULE_PATHNAME' LANGUAGE C IMMUTABLE STRICT PARALLEL SAFE;
|
||||
AS 'MODULE_PATHNAME' LANGUAGE C;
|
||||
|
||||
CREATE FUNCTION ivfflat_bit_support(internal) RETURNS internal
|
||||
AS 'MODULE_PATHNAME' LANGUAGE C IMMUTABLE STRICT PARALLEL SAFE;
|
||||
CREATE FUNCTION ivfflat_vector_update_center(internal, internal, internal) RETURNS internal
|
||||
AS 'MODULE_PATHNAME' LANGUAGE C;
|
||||
|
||||
CREATE FUNCTION ivfflat_halfvec_support(internal) RETURNS internal
|
||||
AS 'MODULE_PATHNAME' LANGUAGE C IMMUTABLE STRICT PARALLEL SAFE;
|
||||
CREATE FUNCTION ivfflat_bit_update_center(internal, internal, internal) RETURNS internal
|
||||
AS 'MODULE_PATHNAME' LANGUAGE C;
|
||||
|
||||
CREATE FUNCTION ivfflat_halfvec_update_center(internal, internal, internal) RETURNS internal
|
||||
AS 'MODULE_PATHNAME' LANGUAGE C;
|
||||
|
||||
CREATE FUNCTION ivfflat_vector_sum_center(internal, internal) RETURNS internal
|
||||
AS 'MODULE_PATHNAME' LANGUAGE C;
|
||||
|
||||
CREATE FUNCTION ivfflat_bit_sum_center(internal, internal) RETURNS internal
|
||||
AS 'MODULE_PATHNAME' LANGUAGE C;
|
||||
|
||||
CREATE FUNCTION ivfflat_halfvec_sum_center(internal, internal) RETURNS internal
|
||||
AS 'MODULE_PATHNAME' LANGUAGE C;
|
||||
|
||||
CREATE FUNCTION hnsw_bit_max_dims(internal) RETURNS internal
|
||||
AS 'MODULE_PATHNAME' LANGUAGE C IMMUTABLE STRICT PARALLEL SAFE;
|
||||
AS 'MODULE_PATHNAME' LANGUAGE C;
|
||||
|
||||
CREATE FUNCTION hnsw_halfvec_max_dims(internal) RETURNS internal
|
||||
AS 'MODULE_PATHNAME' LANGUAGE C IMMUTABLE STRICT PARALLEL SAFE;
|
||||
AS 'MODULE_PATHNAME' LANGUAGE C;
|
||||
|
||||
CREATE FUNCTION hnsw_sparsevec_max_dims(internal) RETURNS internal
|
||||
AS 'MODULE_PATHNAME' LANGUAGE C IMMUTABLE STRICT PARALLEL SAFE;
|
||||
AS 'MODULE_PATHNAME' LANGUAGE C;
|
||||
|
||||
CREATE FUNCTION hnsw_sparsevec_check_value(internal) RETURNS internal
|
||||
AS 'MODULE_PATHNAME' LANGUAGE C IMMUTABLE STRICT PARALLEL SAFE;
|
||||
AS 'MODULE_PATHNAME' LANGUAGE C;
|
||||
|
||||
-- vector opclasses
|
||||
|
||||
@@ -368,7 +380,8 @@ CREATE OPERATOR CLASS bit_hamming_ops
|
||||
FUNCTION 1 hamming_distance(bit, bit),
|
||||
FUNCTION 3 hamming_distance(bit, bit),
|
||||
FUNCTION 6 ivfflat_bit_max_dims(internal),
|
||||
FUNCTION 7 ivfflat_bit_support(internal);
|
||||
FUNCTION 7 ivfflat_bit_update_center(internal, internal, internal),
|
||||
FUNCTION 8 ivfflat_bit_sum_center(internal, internal);
|
||||
|
||||
CREATE OPERATOR CLASS bit_hamming_ops
|
||||
FOR TYPE bit USING hnsw AS
|
||||
@@ -652,7 +665,8 @@ CREATE OPERATOR CLASS halfvec_l2_ops
|
||||
FUNCTION 1 halfvec_l2_squared_distance(halfvec, halfvec),
|
||||
FUNCTION 3 l2_distance(halfvec, halfvec),
|
||||
FUNCTION 6 ivfflat_halfvec_max_dims(internal),
|
||||
FUNCTION 7 ivfflat_halfvec_support(internal);
|
||||
FUNCTION 7 ivfflat_halfvec_update_center(internal, internal, internal),
|
||||
FUNCTION 8 ivfflat_halfvec_sum_center(internal, internal);
|
||||
|
||||
CREATE OPERATOR CLASS halfvec_ip_ops
|
||||
FOR TYPE halfvec USING ivfflat AS
|
||||
@@ -662,7 +676,8 @@ CREATE OPERATOR CLASS halfvec_ip_ops
|
||||
FUNCTION 4 l2_norm(halfvec),
|
||||
FUNCTION 5 l2_normalize(halfvec),
|
||||
FUNCTION 6 ivfflat_halfvec_max_dims(internal),
|
||||
FUNCTION 7 ivfflat_halfvec_support(internal);
|
||||
FUNCTION 7 ivfflat_halfvec_update_center(internal, internal, internal),
|
||||
FUNCTION 8 ivfflat_halfvec_sum_center(internal, internal);
|
||||
|
||||
CREATE OPERATOR CLASS halfvec_cosine_ops
|
||||
FOR TYPE halfvec USING ivfflat AS
|
||||
@@ -673,7 +688,8 @@ CREATE OPERATOR CLASS halfvec_cosine_ops
|
||||
FUNCTION 4 l2_norm(halfvec),
|
||||
FUNCTION 5 l2_normalize(halfvec),
|
||||
FUNCTION 6 ivfflat_halfvec_max_dims(internal),
|
||||
FUNCTION 7 ivfflat_halfvec_support(internal);
|
||||
FUNCTION 7 ivfflat_halfvec_update_center(internal, internal, internal),
|
||||
FUNCTION 8 ivfflat_halfvec_sum_center(internal, internal);
|
||||
|
||||
CREATE OPERATOR CLASS halfvec_l2_ops
|
||||
FOR TYPE halfvec USING hnsw AS
|
||||
|
||||
@@ -978,7 +978,7 @@ halfvec_subvector(PG_FUNCTION_ARGS)
|
||||
/*
|
||||
* Internal helper to compare half vectors
|
||||
*/
|
||||
int
|
||||
static int
|
||||
halfvec_cmp_internal(HalfVector * a, HalfVector * b)
|
||||
{
|
||||
int dim = Min(a->dim, b->dim);
|
||||
|
||||
@@ -44,6 +44,5 @@ typedef struct HalfVector
|
||||
} HalfVector;
|
||||
|
||||
HalfVector *InitHalfVector(int dim);
|
||||
int halfvec_cmp_internal(HalfVector * a, HalfVector * b);
|
||||
|
||||
#endif
|
||||
|
||||
@@ -319,27 +319,6 @@ InsertTuples(Relation index, IvfflatBuildState * buildstate, ForkNumber forkNum)
|
||||
}
|
||||
}
|
||||
|
||||
/*
|
||||
* Get type
|
||||
*/
|
||||
static IvfflatType
|
||||
IvfflatGetType(Relation index)
|
||||
{
|
||||
FmgrInfo *procinfo = IvfflatOptionalProcInfo(index, IVFFLAT_TYPE_SUPPORT_PROC);
|
||||
Oid typid = TupleDescAttr(index->rd_att, 0)->atttypid;
|
||||
IvfflatType result;
|
||||
|
||||
if (procinfo == NULL)
|
||||
return IVFFLAT_TYPE_VECTOR;
|
||||
|
||||
result = (IvfflatType) DatumGetInt32(FunctionCall1(procinfo, ObjectIdGetDatum(typid)));
|
||||
|
||||
if (result == IVFFLAT_TYPE_UNSUPPORTED)
|
||||
elog(ERROR, "type not supported for ivfflat index");
|
||||
|
||||
return result;
|
||||
}
|
||||
|
||||
/*
|
||||
* Get max dimensions
|
||||
*/
|
||||
@@ -358,13 +337,14 @@ GetMaxDimensions(Relation index)
|
||||
* Get item size
|
||||
*/
|
||||
static Size
|
||||
GetItemSize(IvfflatType type, int dimensions)
|
||||
GetItemSize(int maxDimensions, int dimensions)
|
||||
{
|
||||
if (type == IVFFLAT_TYPE_VECTOR)
|
||||
/* TODO Improve */
|
||||
if (maxDimensions == IVFFLAT_MAX_DIM)
|
||||
return VECTOR_SIZE(dimensions);
|
||||
else if (type == IVFFLAT_TYPE_HALFVEC)
|
||||
else if (maxDimensions == IVFFLAT_MAX_DIM * 2)
|
||||
return HALFVEC_SIZE(dimensions);
|
||||
else if (type == IVFFLAT_TYPE_BIT)
|
||||
else if (maxDimensions == IVFFLAT_MAX_DIM * 32)
|
||||
return VARBITTOTALLEN(dimensions);
|
||||
else
|
||||
elog(ERROR, "Unsupported type");
|
||||
@@ -381,7 +361,6 @@ InitBuildState(IvfflatBuildState * buildstate, Relation heap, Relation index, In
|
||||
buildstate->heap = heap;
|
||||
buildstate->index = index;
|
||||
buildstate->indexInfo = indexInfo;
|
||||
buildstate->type = IvfflatGetType(index);
|
||||
|
||||
buildstate->lists = IvfflatGetLists(index);
|
||||
buildstate->dimensions = TupleDescAttr(index->rd_att, 0)->atttypmod;
|
||||
@@ -421,7 +400,7 @@ InitBuildState(IvfflatBuildState * buildstate, Relation heap, Relation index, In
|
||||
|
||||
buildstate->slot = MakeSingleTupleTableSlot(buildstate->tupdesc, &TTSOpsVirtual);
|
||||
|
||||
buildstate->centers = VectorArrayInit(buildstate->lists, buildstate->dimensions, GetItemSize(buildstate->type, buildstate->dimensions));
|
||||
buildstate->centers = VectorArrayInit(buildstate->lists, buildstate->dimensions, GetItemSize(maxDimensions, buildstate->dimensions));
|
||||
buildstate->listInfo = palloc(sizeof(ListInfo) * buildstate->lists);
|
||||
|
||||
buildstate->tmpCtx = AllocSetContextCreate(CurrentMemoryContext,
|
||||
@@ -491,7 +470,7 @@ ComputeCenters(IvfflatBuildState * buildstate)
|
||||
}
|
||||
|
||||
/* Calculate centers */
|
||||
IvfflatBench("k-means", IvfflatKmeans(buildstate->index, buildstate->samples, buildstate->centers, buildstate->type));
|
||||
IvfflatBench("k-means", IvfflatKmeans(buildstate->index, buildstate->samples, buildstate->centers));
|
||||
|
||||
/* Free samples before we allocate more memory */
|
||||
VectorArrayFree(buildstate->samples);
|
||||
|
||||
@@ -188,7 +188,7 @@ ivfflathandler(PG_FUNCTION_ARGS)
|
||||
IndexAmRoutine *amroutine = makeNode(IndexAmRoutine);
|
||||
|
||||
amroutine->amstrategies = 0;
|
||||
amroutine->amsupport = 7;
|
||||
amroutine->amsupport = 8;
|
||||
#if PG_VERSION_NUM >= 130000
|
||||
amroutine->amoptsprocnum = 0;
|
||||
#endif
|
||||
|
||||
@@ -30,7 +30,8 @@
|
||||
#define IVFFLAT_KMEANS_NORM_PROC 4
|
||||
#define IVFFLAT_NORMALIZE_PROC 5
|
||||
#define IVFFLAT_MAX_DIMS_PROC 6
|
||||
#define IVFFLAT_TYPE_SUPPORT_PROC 7
|
||||
#define IVFFLAT_UPDATE_CENTER_PROC 7
|
||||
#define IVFFLAT_SUM_CENTER_PROC 8
|
||||
|
||||
#define IVFFLAT_VERSION 1
|
||||
#define IVFFLAT_MAGIC_NUMBER 0x14FF1A7
|
||||
@@ -46,14 +47,6 @@
|
||||
#define IVFFLAT_MAX_LISTS 32768
|
||||
#define IVFFLAT_DEFAULT_PROBES 1
|
||||
|
||||
typedef enum IvfflatType
|
||||
{
|
||||
IVFFLAT_TYPE_VECTOR,
|
||||
IVFFLAT_TYPE_HALFVEC,
|
||||
IVFFLAT_TYPE_BIT,
|
||||
IVFFLAT_TYPE_UNSUPPORTED
|
||||
} IvfflatType;
|
||||
|
||||
/* Build phases */
|
||||
/* PROGRESS_CREATEIDX_SUBPHASE_INITIALIZE is 1 */
|
||||
#define PROGRESS_IVFFLAT_PHASE_KMEANS 2
|
||||
@@ -165,7 +158,6 @@ typedef struct IvfflatBuildState
|
||||
Relation heap;
|
||||
Relation index;
|
||||
IndexInfo *indexInfo;
|
||||
IvfflatType type;
|
||||
|
||||
/* Settings */
|
||||
int dimensions;
|
||||
@@ -279,7 +271,7 @@ typedef IvfflatScanOpaqueData * IvfflatScanOpaque;
|
||||
/* Methods */
|
||||
VectorArray VectorArrayInit(int maxlen, int dimensions, Size itemsize);
|
||||
void VectorArrayFree(VectorArray arr);
|
||||
void IvfflatKmeans(Relation index, VectorArray samples, VectorArray centers, IvfflatType type);
|
||||
void IvfflatKmeans(Relation index, VectorArray samples, VectorArray centers);
|
||||
FmgrInfo *IvfflatOptionalProcInfo(Relation index, uint16 procnum);
|
||||
Datum IvfflatNormValue(FmgrInfo *procinfo, Oid collation, Datum value);
|
||||
bool IvfflatCheckNorm(FmgrInfo *procinfo, Oid collation, Datum value);
|
||||
|
||||
291
src/ivfkmeans.c
291
src/ivfkmeans.c
@@ -13,14 +13,15 @@
|
||||
#include "utils/memutils.h"
|
||||
#include "vector.h"
|
||||
|
||||
/* Support functions */
|
||||
PGDLLEXPORT Datum ivfflat_vector_update_center(PG_FUNCTION_ARGS);
|
||||
PGDLLEXPORT Datum ivfflat_vector_sum_center(PG_FUNCTION_ARGS);
|
||||
|
||||
typedef struct KmeansState
|
||||
{
|
||||
void (*initCenter) (Pointer v, int dimensions);
|
||||
void (*updateCenter) (Pointer v, float *x);
|
||||
void (*sumCenter) (Pointer v, float *x);
|
||||
int (*comp) (const void *a, const void *b);
|
||||
bool separateAgg;
|
||||
bool checkDuplicates;
|
||||
int dimensions;
|
||||
FmgrInfo *updatecenterprocinfo;
|
||||
FmgrInfo *sumcenterprocinfo;
|
||||
} KmeansState;
|
||||
|
||||
/*
|
||||
@@ -126,105 +127,20 @@ NormCenters(FmgrInfo *normalizeprocinfo, Oid collation, VectorArray centers)
|
||||
MemoryContextDelete(normCtx);
|
||||
}
|
||||
|
||||
/*
|
||||
* Compare vectors
|
||||
*/
|
||||
static int
|
||||
CompareVectors(const void *a, const void *b)
|
||||
static void
|
||||
UpdateCenter(FmgrInfo *procinfo, Pointer center, int dimensions, float *x)
|
||||
{
|
||||
return vector_cmp_internal((Vector *) a, (Vector *) b);
|
||||
if (procinfo == NULL)
|
||||
DirectFunctionCall3(ivfflat_vector_update_center, PointerGetDatum(center), Int32GetDatum(dimensions), PointerGetDatum(x));
|
||||
else
|
||||
FunctionCall3(procinfo, PointerGetDatum(center), Int32GetDatum(dimensions), PointerGetDatum(x));
|
||||
}
|
||||
|
||||
/*
|
||||
* Compare half vectors
|
||||
*/
|
||||
static int
|
||||
CompareHalfVectors(const void *a, const void *b)
|
||||
{
|
||||
return halfvec_cmp_internal((HalfVector *) a, (HalfVector *) b);
|
||||
}
|
||||
|
||||
/*
|
||||
* Compare bit vectors
|
||||
*/
|
||||
static int
|
||||
CompareBitVectors(const void *a, const void *b)
|
||||
{
|
||||
return DirectFunctionCall2(bitcmp, VarBitPGetDatum((VarBit *) a), VarBitPGetDatum((VarBit *) b));
|
||||
}
|
||||
|
||||
/*
|
||||
* Sort vector array
|
||||
* Quick approach if we have no data
|
||||
*/
|
||||
static void
|
||||
SortVectorArray(VectorArray arr, KmeansState * kmeansstate)
|
||||
{
|
||||
qsort(arr->items, arr->length, arr->itemsize, kmeansstate->comp);
|
||||
}
|
||||
|
||||
static void
|
||||
VectorInitCenter(Pointer v, int dimensions)
|
||||
{
|
||||
Vector *vec = (Vector *) v;
|
||||
|
||||
SET_VARSIZE(vec, VECTOR_SIZE(dimensions));
|
||||
vec->dim = dimensions;
|
||||
}
|
||||
|
||||
static void
|
||||
HalfvecInitCenter(Pointer v, int dimensions)
|
||||
{
|
||||
HalfVector *vec = (HalfVector *) v;
|
||||
|
||||
SET_VARSIZE(vec, HALFVEC_SIZE(dimensions));
|
||||
vec->dim = dimensions;
|
||||
}
|
||||
|
||||
static void
|
||||
BitInitCenter(Pointer v, int dimensions)
|
||||
{
|
||||
VarBit *vec = (VarBit *) v;
|
||||
|
||||
SET_VARSIZE(vec, VARBITTOTALLEN(dimensions));
|
||||
VARBITLEN(vec) = dimensions;
|
||||
}
|
||||
|
||||
static void
|
||||
VectorUpdateCenter(Pointer v, float *x)
|
||||
{
|
||||
Vector *newCenter = (Vector *) v;
|
||||
|
||||
for (int k = 0; k < newCenter->dim; k++)
|
||||
newCenter->x[k] = x[k];
|
||||
}
|
||||
|
||||
static void
|
||||
HalfvecUpdateCenter(Pointer v, float *x)
|
||||
{
|
||||
HalfVector *newCenter = (HalfVector *) v;
|
||||
|
||||
for (int k = 0; k < newCenter->dim; k++)
|
||||
newCenter->x[k] = Float4ToHalfUnchecked(x[k]);
|
||||
}
|
||||
|
||||
static void
|
||||
BitUpdateCenter(Pointer v, float *x)
|
||||
{
|
||||
VarBit *newCenter = (VarBit *) v;
|
||||
unsigned char *nx = VARBITS(newCenter);
|
||||
|
||||
for (uint32 k = 0; k < VARBITBYTES(newCenter); k++)
|
||||
nx[k] = 0;
|
||||
|
||||
for (int k = 0; k < VARBITLEN(newCenter); k++)
|
||||
nx[k / 8] |= (x[k] > 0.5 ? 1 : 0) << (7 - (k % 8));
|
||||
}
|
||||
|
||||
/*
|
||||
* Quick approach if we have little data
|
||||
*/
|
||||
static void
|
||||
QuickCenters(Relation index, VectorArray samples, VectorArray centers, KmeansState * kmeansstate)
|
||||
RandomCenters(Relation index, VectorArray centers, KmeansState * kmeansstate)
|
||||
{
|
||||
int dimensions = centers->dim;
|
||||
Oid collation = index->rd_indcollation[0];
|
||||
@@ -232,24 +148,7 @@ QuickCenters(Relation index, VectorArray samples, VectorArray centers, KmeansSta
|
||||
FmgrInfo *normalizeprocinfo = IvfflatOptionalProcInfo(index, IVFFLAT_NORMALIZE_PROC);
|
||||
float *x = (float *) palloc(sizeof(float) * dimensions);
|
||||
|
||||
/* Copy existing vectors while avoiding duplicates */
|
||||
if (samples->length > 0)
|
||||
{
|
||||
SortVectorArray(samples, kmeansstate);
|
||||
|
||||
for (int i = 0; i < samples->length; i++)
|
||||
{
|
||||
Datum vec = PointerGetDatum(VectorArrayGet(samples, i));
|
||||
|
||||
if (i == 0 || !datumIsEqual(vec, PointerGetDatum(VectorArrayGet(samples, i - 1)), false, -1))
|
||||
{
|
||||
VectorArraySet(centers, centers->length, DatumGetPointer(vec));
|
||||
centers->length++;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/* Fill remaining with random data */
|
||||
/* Fill with random data */
|
||||
while (centers->length < centers->maxlen)
|
||||
{
|
||||
Pointer center = VectorArrayGet(centers, centers->length);
|
||||
@@ -257,13 +156,11 @@ QuickCenters(Relation index, VectorArray samples, VectorArray centers, KmeansSta
|
||||
for (int i = 0; i < dimensions; i++)
|
||||
x[i] = (float) RandomDouble();
|
||||
|
||||
kmeansstate->initCenter(center, dimensions);
|
||||
kmeansstate->updateCenter(center, x);
|
||||
UpdateCenter(kmeansstate->updatecenterprocinfo, center, dimensions, x);
|
||||
|
||||
centers->length++;
|
||||
}
|
||||
|
||||
/* Fine if existing vectors are normalized twice */
|
||||
if (normprocinfo != NULL)
|
||||
NormCenters(normalizeprocinfo, collation, centers);
|
||||
|
||||
@@ -288,57 +185,39 @@ ShowMemoryUsage(MemoryContext context, Size estimatedSize)
|
||||
#endif
|
||||
|
||||
static void
|
||||
VectorSumCenter(Pointer v, float *x)
|
||||
SumCenter(FmgrInfo *procinfo, Pointer sample, float *x)
|
||||
{
|
||||
Vector *vec = (Vector *) v;
|
||||
|
||||
for (int k = 0; k < vec->dim; k++)
|
||||
x[k] += vec->x[k];
|
||||
}
|
||||
|
||||
static void
|
||||
HalfvecSumCenter(Pointer v, float *x)
|
||||
{
|
||||
HalfVector *vec = (HalfVector *) v;
|
||||
|
||||
for (int k = 0; k < vec->dim; k++)
|
||||
x[k] += HalfToFloat4(vec->x[k]);
|
||||
}
|
||||
|
||||
static void
|
||||
BitSumCenter(Pointer v, float *x)
|
||||
{
|
||||
VarBit *vec = (VarBit *) v;
|
||||
|
||||
for (int k = 0; k < VARBITLEN(v); k++)
|
||||
x[k] += (float) (((VARBITS(vec)[k / 8]) >> (7 - (k % 8))) & 0x01);
|
||||
if (procinfo == NULL)
|
||||
DirectFunctionCall2(ivfflat_vector_sum_center, PointerGetDatum(sample), PointerGetDatum(x));
|
||||
else
|
||||
FunctionCall2(procinfo, PointerGetDatum(sample), PointerGetDatum(x));
|
||||
}
|
||||
|
||||
/*
|
||||
* Sum centers
|
||||
*/
|
||||
static void
|
||||
SumCenters(VectorArray samples, VectorArray aggCenters, int *closestCenters, KmeansState * kmeansstate)
|
||||
SumCenters(VectorArray samples, float *agg, int *closestCenters, KmeansState * kmeansstate)
|
||||
{
|
||||
for (int j = 0; j < samples->length; j++)
|
||||
{
|
||||
Vector *aggCenter = (Vector *) VectorArrayGet(aggCenters, closestCenters[j]);
|
||||
float *x = agg + ((int64) closestCenters[j] * kmeansstate->dimensions);
|
||||
|
||||
kmeansstate->sumCenter(VectorArrayGet(samples, j), aggCenter->x);
|
||||
SumCenter(kmeansstate->sumcenterprocinfo, VectorArrayGet(samples, j), x);
|
||||
}
|
||||
}
|
||||
|
||||
/*
|
||||
* Set new centers
|
||||
* Update centers
|
||||
*/
|
||||
static void
|
||||
UpdateCenters(VectorArray aggCenters, VectorArray newCenters, KmeansState * kmeansstate)
|
||||
UpdateCenters(float *agg, VectorArray centers, KmeansState * kmeansstate)
|
||||
{
|
||||
for (int j = 0; j < aggCenters->length; j++)
|
||||
for (int j = 0; j < centers->length; j++)
|
||||
{
|
||||
Vector *aggCenter = (Vector *) VectorArrayGet(aggCenters, j);
|
||||
float *x = agg + ((int64) j * kmeansstate->dimensions);
|
||||
|
||||
kmeansstate->updateCenter(VectorArrayGet(newCenters, j), aggCenter->x);
|
||||
UpdateCenter(kmeansstate->updatecenterprocinfo, VectorArrayGet(centers, j), centers->dim, x);
|
||||
}
|
||||
}
|
||||
|
||||
@@ -346,25 +225,25 @@ UpdateCenters(VectorArray aggCenters, VectorArray newCenters, KmeansState * kmea
|
||||
* Compute new centers
|
||||
*/
|
||||
static void
|
||||
ComputeNewCenters(VectorArray samples, VectorArray aggCenters, VectorArray newCenters, int *centerCounts, int *closestCenters, FmgrInfo *normprocinfo, FmgrInfo *normalizeprocinfo, Oid collation, KmeansState * kmeansstate)
|
||||
ComputeNewCenters(VectorArray samples, float *agg, VectorArray newCenters, int *centerCounts, int *closestCenters, FmgrInfo *normprocinfo, FmgrInfo *normalizeprocinfo, Oid collation, KmeansState * kmeansstate)
|
||||
{
|
||||
int dimensions = aggCenters->dim;
|
||||
int numCenters = aggCenters->maxlen;
|
||||
int dimensions = kmeansstate->dimensions;
|
||||
int numCenters = newCenters->length;
|
||||
int numSamples = samples->length;
|
||||
|
||||
/* Reset sum and count */
|
||||
for (int j = 0; j < numCenters; j++)
|
||||
{
|
||||
Vector *vec = (Vector *) VectorArrayGet(aggCenters, j);
|
||||
float *x = agg + ((int64) j * dimensions);
|
||||
|
||||
for (int k = 0; k < dimensions; k++)
|
||||
vec->x[k] = 0.0;
|
||||
x[k] = 0.0;
|
||||
|
||||
centerCounts[j] = 0;
|
||||
}
|
||||
|
||||
/* Increment sum of closest center */
|
||||
SumCenters(samples, aggCenters, closestCenters, kmeansstate);
|
||||
SumCenters(samples, agg, closestCenters, kmeansstate);
|
||||
|
||||
/* Increment count of closest center */
|
||||
for (int j = 0; j < numSamples; j++)
|
||||
@@ -373,7 +252,7 @@ ComputeNewCenters(VectorArray samples, VectorArray aggCenters, VectorArray newCe
|
||||
/* Divide sum by count */
|
||||
for (int j = 0; j < numCenters; j++)
|
||||
{
|
||||
Vector *vec = (Vector *) VectorArrayGet(aggCenters, j);
|
||||
float *x = agg + ((int64) j * dimensions);
|
||||
|
||||
if (centerCounts[j] > 0)
|
||||
{
|
||||
@@ -381,24 +260,23 @@ ComputeNewCenters(VectorArray samples, VectorArray aggCenters, VectorArray newCe
|
||||
/* TODO Update bounds */
|
||||
for (int k = 0; k < dimensions; k++)
|
||||
{
|
||||
if (isinf(vec->x[k]))
|
||||
vec->x[k] = vec->x[k] > 0 ? FLT_MAX : -FLT_MAX;
|
||||
if (isinf(x[k]))
|
||||
x[k] = x[k] > 0 ? FLT_MAX : -FLT_MAX;
|
||||
}
|
||||
|
||||
for (int k = 0; k < dimensions; k++)
|
||||
vec->x[k] /= centerCounts[j];
|
||||
x[k] /= centerCounts[j];
|
||||
}
|
||||
else
|
||||
{
|
||||
/* TODO Handle empty centers properly */
|
||||
for (int k = 0; k < dimensions; k++)
|
||||
vec->x[k] = RandomDouble();
|
||||
x[k] = RandomDouble();
|
||||
}
|
||||
}
|
||||
|
||||
/* Set new centers if different from agg centers */
|
||||
if (kmeansstate->separateAgg)
|
||||
UpdateCenters(aggCenters, newCenters, kmeansstate);
|
||||
/* Set new centers */
|
||||
UpdateCenters(agg, newCenters, kmeansstate);
|
||||
|
||||
/* Normalize if needed */
|
||||
if (normprocinfo != NULL)
|
||||
@@ -424,7 +302,7 @@ ElkanKmeans(Relation index, VectorArray samples, VectorArray centers, KmeansStat
|
||||
int numCenters = centers->maxlen;
|
||||
int numSamples = samples->length;
|
||||
VectorArray newCenters;
|
||||
VectorArray aggCenters;
|
||||
float *agg;
|
||||
int *centerCounts;
|
||||
int *closestCenters;
|
||||
float *lowerBound;
|
||||
@@ -439,7 +317,7 @@ ElkanKmeans(Relation index, VectorArray samples, VectorArray centers, KmeansStat
|
||||
Size samplesSize = VECTOR_ARRAY_SIZE(samples->maxlen, samples->itemsize);
|
||||
Size centersSize = VECTOR_ARRAY_SIZE(centers->maxlen, centers->itemsize);
|
||||
Size newCentersSize = VECTOR_ARRAY_SIZE(numCenters, centers->itemsize);
|
||||
Size aggCentersSize = !kmeansstate->separateAgg ? 0 : VECTOR_ARRAY_SIZE(numCenters, VECTOR_SIZE(dimensions));
|
||||
Size aggSize = sizeof(float) * (int64) numCenters * dimensions;
|
||||
Size centerCountsSize = sizeof(int) * numCenters;
|
||||
Size closestCentersSize = sizeof(int) * numSamples;
|
||||
Size lowerBoundSize = sizeof(float) * numSamples * numCenters;
|
||||
@@ -449,7 +327,7 @@ ElkanKmeans(Relation index, VectorArray samples, VectorArray centers, KmeansStat
|
||||
Size newcdistSize = sizeof(float) * numCenters;
|
||||
|
||||
/* Calculate total size */
|
||||
Size totalSize = samplesSize + centersSize + newCentersSize + aggCentersSize + centerCountsSize + closestCentersSize + lowerBoundSize + upperBoundSize + sSize + halfcdistSize + newcdistSize;
|
||||
Size totalSize = samplesSize + centersSize + newCentersSize + aggSize + centerCountsSize + closestCentersSize + lowerBoundSize + upperBoundSize + sSize + halfcdistSize + newcdistSize;
|
||||
|
||||
/* Check memory requirements */
|
||||
/* Add one to error message to ceil */
|
||||
@@ -477,6 +355,7 @@ ElkanKmeans(Relation index, VectorArray samples, VectorArray centers, KmeansStat
|
||||
|
||||
/* Allocate space */
|
||||
/* Use float instead of double to save memory */
|
||||
agg = palloc(aggSize);
|
||||
centerCounts = palloc(centerCountsSize);
|
||||
closestCenters = palloc(closestCentersSize);
|
||||
lowerBound = palloc_extended(lowerBoundSize, MCXT_ALLOC_HUGE);
|
||||
@@ -489,24 +368,6 @@ ElkanKmeans(Relation index, VectorArray samples, VectorArray centers, KmeansStat
|
||||
newCenters = VectorArrayInit(numCenters, dimensions, centers->itemsize);
|
||||
newCenters->length = numCenters;
|
||||
|
||||
for (int j = 0; j < numCenters; j++)
|
||||
kmeansstate->initCenter(VectorArrayGet(newCenters, j), dimensions);
|
||||
|
||||
/* Initialize agg centers */
|
||||
if (!kmeansstate->separateAgg)
|
||||
{
|
||||
/* Use same centers to save memory */
|
||||
aggCenters = newCenters;
|
||||
}
|
||||
else
|
||||
{
|
||||
aggCenters = VectorArrayInit(numCenters, dimensions, VECTOR_SIZE(dimensions));
|
||||
aggCenters->length = numCenters;
|
||||
|
||||
for (int j = 0; j < numCenters; j++)
|
||||
VectorInitCenter(VectorArrayGet(aggCenters, j), dimensions);
|
||||
}
|
||||
|
||||
#ifdef IVFFLAT_MEMORY
|
||||
ShowMemoryUsage(oldCtx, totalSize);
|
||||
#endif
|
||||
@@ -645,7 +506,7 @@ ElkanKmeans(Relation index, VectorArray samples, VectorArray centers, KmeansStat
|
||||
}
|
||||
|
||||
/* Step 4: For each center c, let m(c) be mean of all points assigned */
|
||||
ComputeNewCenters(samples, aggCenters, newCenters, centerCounts, closestCenters, normprocinfo, normalizeprocinfo, collation, kmeansstate);
|
||||
ComputeNewCenters(samples, agg, newCenters, centerCounts, closestCenters, normprocinfo, normalizeprocinfo, collation, kmeansstate);
|
||||
|
||||
/* Step 5 */
|
||||
for (int j = 0; j < numCenters; j++)
|
||||
@@ -699,7 +560,7 @@ CheckCenters(Relation index, VectorArray centers, KmeansState * kmeansstate)
|
||||
for (int j = 0; j < centers->dim; j++)
|
||||
scratch[j] = 0;
|
||||
|
||||
kmeansstate->sumCenter(VectorArrayGet(centers, i), scratch);
|
||||
SumCenter(kmeansstate->sumcenterprocinfo, VectorArrayGet(centers, i), scratch);
|
||||
|
||||
for (int j = 0; j < centers->dim; j++)
|
||||
{
|
||||
@@ -711,18 +572,6 @@ CheckCenters(Relation index, VectorArray centers, KmeansState * kmeansstate)
|
||||
}
|
||||
}
|
||||
|
||||
if (kmeansstate->checkDuplicates)
|
||||
{
|
||||
/* Ensure no duplicate centers */
|
||||
SortVectorArray(centers, kmeansstate);
|
||||
|
||||
for (int i = 1; i < centers->length; i++)
|
||||
{
|
||||
if (datumIsEqual(PointerGetDatum(VectorArrayGet(centers, i)), PointerGetDatum(VectorArrayGet(centers, i - 1)), false, -1))
|
||||
elog(ERROR, "Duplicate centers detected. Please report a bug.");
|
||||
}
|
||||
}
|
||||
|
||||
/* Ensure no zero vectors for cosine distance */
|
||||
/* Check NORM_PROC instead of KMEANS_NORM_PROC */
|
||||
normprocinfo = IvfflatOptionalProcInfo(index, IVFFLAT_NORM_PROC);
|
||||
@@ -743,37 +592,11 @@ CheckCenters(Relation index, VectorArray centers, KmeansState * kmeansstate)
|
||||
}
|
||||
|
||||
static void
|
||||
InitKmeansState(KmeansState * kmeansstate, IvfflatType type)
|
||||
InitKmeansState(KmeansState * kmeansstate, Relation index, int dimensions)
|
||||
{
|
||||
if (type == IVFFLAT_TYPE_VECTOR)
|
||||
{
|
||||
kmeansstate->initCenter = VectorInitCenter;
|
||||
kmeansstate->updateCenter = VectorUpdateCenter;
|
||||
kmeansstate->sumCenter = VectorSumCenter;
|
||||
kmeansstate->comp = CompareVectors;
|
||||
kmeansstate->separateAgg = false;
|
||||
kmeansstate->checkDuplicates = true;
|
||||
}
|
||||
else if (type == IVFFLAT_TYPE_HALFVEC)
|
||||
{
|
||||
kmeansstate->initCenter = HalfvecInitCenter;
|
||||
kmeansstate->updateCenter = HalfvecUpdateCenter;
|
||||
kmeansstate->sumCenter = HalfvecSumCenter;
|
||||
kmeansstate->comp = CompareHalfVectors;
|
||||
kmeansstate->separateAgg = true;
|
||||
kmeansstate->checkDuplicates = true;
|
||||
}
|
||||
else if (type == IVFFLAT_TYPE_BIT)
|
||||
{
|
||||
kmeansstate->initCenter = BitInitCenter;
|
||||
kmeansstate->updateCenter = BitUpdateCenter;
|
||||
kmeansstate->sumCenter = BitSumCenter;
|
||||
kmeansstate->comp = CompareBitVectors;
|
||||
kmeansstate->separateAgg = true;
|
||||
kmeansstate->checkDuplicates = false;
|
||||
}
|
||||
else
|
||||
elog(ERROR, "Unsupported type");
|
||||
kmeansstate->dimensions = dimensions;
|
||||
kmeansstate->updatecenterprocinfo = IvfflatOptionalProcInfo(index, IVFFLAT_UPDATE_CENTER_PROC);
|
||||
kmeansstate->sumcenterprocinfo = IvfflatOptionalProcInfo(index, IVFFLAT_SUM_CENTER_PROC);
|
||||
}
|
||||
|
||||
/*
|
||||
@@ -781,14 +604,14 @@ InitKmeansState(KmeansState * kmeansstate, IvfflatType type)
|
||||
* We use spherical k-means for inner product and cosine
|
||||
*/
|
||||
void
|
||||
IvfflatKmeans(Relation index, VectorArray samples, VectorArray centers, IvfflatType type)
|
||||
IvfflatKmeans(Relation index, VectorArray samples, VectorArray centers)
|
||||
{
|
||||
KmeansState kmeansstate;
|
||||
|
||||
InitKmeansState(&kmeansstate, type);
|
||||
InitKmeansState(&kmeansstate, index, centers->dim);
|
||||
|
||||
if (samples->length <= centers->maxlen)
|
||||
QuickCenters(index, samples, centers, &kmeansstate);
|
||||
if (samples->length == 0)
|
||||
RandomCenters(index, centers, &kmeansstate);
|
||||
else
|
||||
ElkanKmeans(index, samples, centers, &kmeansstate);
|
||||
|
||||
|
||||
@@ -1,8 +1,11 @@
|
||||
#include "postgres.h"
|
||||
|
||||
#include "access/generic_xlog.h"
|
||||
#include "bitvec.h"
|
||||
#include "catalog/pg_type.h"
|
||||
#include "fmgr.h"
|
||||
#include "halfutils.h"
|
||||
#include "halfvec.h"
|
||||
#include "ivfflat.h"
|
||||
#include "storage/bufmgr.h"
|
||||
|
||||
@@ -239,21 +242,96 @@ ivfflat_bit_max_dims(PG_FUNCTION_ARGS)
|
||||
PG_RETURN_INT32(IVFFLAT_MAX_DIM * 32);
|
||||
};
|
||||
|
||||
PGDLLEXPORT PG_FUNCTION_INFO_V1(ivfflat_halfvec_support);
|
||||
PGDLLEXPORT PG_FUNCTION_INFO_V1(ivfflat_vector_update_center);
|
||||
Datum
|
||||
ivfflat_halfvec_support(PG_FUNCTION_ARGS)
|
||||
ivfflat_vector_update_center(PG_FUNCTION_ARGS)
|
||||
{
|
||||
PG_RETURN_INT32(IVFFLAT_TYPE_HALFVEC);
|
||||
Vector *vec = PG_GETARG_VECTOR_P(0);
|
||||
int dimensions = PG_GETARG_INT32(1);
|
||||
float *x = (float *) PG_GETARG_POINTER(2);
|
||||
|
||||
SET_VARSIZE(vec, VECTOR_SIZE(dimensions));
|
||||
vec->dim = dimensions;
|
||||
|
||||
for (int k = 0; k < dimensions; k++)
|
||||
vec->x[k] = x[k];
|
||||
|
||||
PG_RETURN_VOID();
|
||||
};
|
||||
|
||||
PGDLLEXPORT PG_FUNCTION_INFO_V1(ivfflat_bit_support);
|
||||
PGDLLEXPORT PG_FUNCTION_INFO_V1(ivfflat_halfvec_update_center);
|
||||
Datum
|
||||
ivfflat_bit_support(PG_FUNCTION_ARGS)
|
||||
ivfflat_halfvec_update_center(PG_FUNCTION_ARGS)
|
||||
{
|
||||
Oid typid = PG_GETARG_OID(0);
|
||||
HalfVector *vec = PG_GETARG_HALFVEC_P(0);
|
||||
int dimensions = PG_GETARG_INT32(1);
|
||||
float *x = (float *) PG_GETARG_POINTER(2);
|
||||
|
||||
if (typid == BITOID)
|
||||
PG_RETURN_INT32(IVFFLAT_TYPE_BIT);
|
||||
else
|
||||
PG_RETURN_INT32(IVFFLAT_TYPE_UNSUPPORTED);
|
||||
SET_VARSIZE(vec, HALFVEC_SIZE(dimensions));
|
||||
vec->dim = dimensions;
|
||||
|
||||
for (int k = 0; k < dimensions; k++)
|
||||
vec->x[k] = Float4ToHalfUnchecked(x[k]);
|
||||
|
||||
PG_RETURN_VOID();
|
||||
};
|
||||
|
||||
PGDLLEXPORT PG_FUNCTION_INFO_V1(ivfflat_bit_update_center);
|
||||
Datum
|
||||
ivfflat_bit_update_center(PG_FUNCTION_ARGS)
|
||||
{
|
||||
VarBit *vec = PG_GETARG_VARBIT_P(0);
|
||||
int dimensions = PG_GETARG_INT32(1);
|
||||
float *x = (float *) PG_GETARG_POINTER(2);
|
||||
unsigned char *nx = VARBITS(vec);
|
||||
|
||||
SET_VARSIZE(vec, VARBITTOTALLEN(dimensions));
|
||||
VARBITLEN(vec) = dimensions;
|
||||
|
||||
for (uint32 k = 0; k < VARBITBYTES(vec); k++)
|
||||
nx[k] = 0;
|
||||
|
||||
for (int k = 0; k < dimensions; k++)
|
||||
nx[k / 8] |= (x[k] > 0.5 ? 1 : 0) << (7 - (k % 8));
|
||||
|
||||
PG_RETURN_VOID();
|
||||
};
|
||||
|
||||
PGDLLEXPORT PG_FUNCTION_INFO_V1(ivfflat_vector_sum_center);
|
||||
Datum
|
||||
ivfflat_vector_sum_center(PG_FUNCTION_ARGS)
|
||||
{
|
||||
Vector *vec = PG_GETARG_VECTOR_P(0);
|
||||
float *x = (float *) PG_GETARG_POINTER(1);
|
||||
|
||||
for (int k = 0; k < vec->dim; k++)
|
||||
x[k] += vec->x[k];
|
||||
|
||||
PG_RETURN_VOID();
|
||||
};
|
||||
|
||||
PGDLLEXPORT PG_FUNCTION_INFO_V1(ivfflat_halfvec_sum_center);
|
||||
Datum
|
||||
ivfflat_halfvec_sum_center(PG_FUNCTION_ARGS)
|
||||
{
|
||||
HalfVector *vec = PG_GETARG_HALFVEC_P(0);
|
||||
float *x = (float *) PG_GETARG_POINTER(1);
|
||||
|
||||
for (int k = 0; k < vec->dim; k++)
|
||||
x[k] += HalfToFloat4(vec->x[k]);
|
||||
|
||||
PG_RETURN_VOID();
|
||||
}
|
||||
|
||||
PGDLLEXPORT PG_FUNCTION_INFO_V1(ivfflat_bit_sum_center);
|
||||
Datum
|
||||
ivfflat_bit_sum_center(PG_FUNCTION_ARGS)
|
||||
{
|
||||
VarBit *vec = PG_GETARG_VARBIT_P(0);
|
||||
float *x = (float *) PG_GETARG_POINTER(1);
|
||||
|
||||
for (int k = 0; k < VARBITLEN(vec); k++)
|
||||
x[k] += (float) (((VARBITS(vec)[k / 8]) >> (7 - (k % 8))) & 0x01);
|
||||
|
||||
PG_RETURN_VOID();
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user