Files
pgvector/src/ivfutils.c
Jon Daniel fe697e8788 vectorize: optimize VectorSumCenter and HalfvecSumCenter (#860)
* vectorize: optimize VectorSumCenter and HalfvecSumCenter

The functions VectorSumCenter and HalfvecSumCenter were not being
vectorized by the compiler. A few slight changes will allow these
optimizations to take place and get a performance boost by utilizing
SIMD instructions.

This optimization helps improve performance of vector operations in IVF
index building and updating.

* Removing const, commenting that it is only vectorized on ARM
2025-06-18 16:09:43 -07:00

378 lines
7.9 KiB
C

#include "postgres.h"
#include "access/generic_xlog.h"
#include "bitvec.h"
#include "catalog/pg_type.h"
#include "fmgr.h"
#include "halfutils.h"
#include "halfvec.h"
#include "ivfflat.h"
#include "storage/bufmgr.h"
/*
 * Allocate a vector array
 *
 * Returns an empty VectorArray (length 0) able to hold maxlen items.
 * itemsize is rounded up with MAXALIGN and stored as the per-item stride.
 * The item storage is zero-filled (MCXT_ALLOC_ZERO) and may exceed the
 * normal palloc size limit (MCXT_ALLOC_HUGE).
 */
VectorArray
VectorArrayInit(int maxlen, int dimensions, Size itemsize)
{
	/* Aligned stride so every item starts on a MAXALIGN boundary (avoids UB) */
	Size		stride = MAXALIGN(itemsize);
	VectorArray arr = palloc(sizeof(VectorArrayData));

	arr->length = 0;
	arr->maxlen = maxlen;
	arr->dim = dimensions;
	arr->itemsize = stride;
	arr->items = palloc_extended(maxlen * stride, MCXT_ALLOC_ZERO | MCXT_ALLOC_HUGE);

	return arr;
}
/*
 * Free a vector array
 *
 * Releases the item storage first, then the VectorArray header itself.
 */
void
VectorArrayFree(VectorArray arr)
{
	void	   *storage = arr->items;

	pfree(storage);
	pfree(arr);
}
/*
 * Get the number of lists in the index
 *
 * Uses the lists value from the index's IvfflatOptions when rd_options is
 * set, otherwise IVFFLAT_DEFAULT_LISTS.
 */
int
IvfflatGetLists(Relation index)
{
	IvfflatOptions *options = (IvfflatOptions *) index->rd_options;

	return options != NULL ? options->lists : IVFFLAT_DEFAULT_LISTS;
}
/*
 * Get proc
 *
 * Looks up support procedure procnum for index column 1. Returns its
 * FmgrInfo, or NULL when no such procedure is registered (its OID is
 * invalid), so callers can treat the procedure as optional.
 */
FmgrInfo *
IvfflatOptionalProcInfo(Relation index, uint16 procnum)
{
	Oid			procOid = index_getprocid(index, 1, procnum);

	if (OidIsValid(procOid))
		return index_getprocinfo(index, 1, procnum);

	return NULL;
}
/*
 * Normalize value
 *
 * Calls typeInfo->normalize directly with the given collation and returns
 * the resulting datum. NOTE(review): normalize is NULL for the bit type
 * info, so callers must not use this for bit indexes.
 */
Datum
IvfflatNormValue(const IvfflatTypeInfo * typeInfo, Oid collation, Datum value)
{
	Datum		normalized;

	normalized = DirectFunctionCall1Coll(typeInfo->normalize, collation, value);
	return normalized;
}
/*
 * Check if non-zero norm
 *
 * Invokes the norm function in procinfo on value and reports whether the
 * float8 result is strictly greater than zero.
 */
bool
IvfflatCheckNorm(FmgrInfo *procinfo, Oid collation, Datum value)
{
	double		norm = DatumGetFloat8(FunctionCall1Coll(procinfo, collation, value));

	return norm > 0;
}
/*
 * New buffer
 *
 * Reads a new block (P_NEW) in forkNum of the index and returns its buffer
 * with an exclusive lock held.
 */
Buffer
IvfflatNewBuffer(Relation index, ForkNumber forkNum)
{
	Buffer		newbuf;

	newbuf = ReadBufferExtended(index, forkNum, P_NEW, RBM_NORMAL, NULL);
	LockBuffer(newbuf, BUFFER_LOCK_EXCLUSIVE);

	return newbuf;
}
/*
 * Init page
 *
 * Formats page as an empty page sized for buf, with IvfflatPageOpaqueData
 * special space, no next block, and the ivfflat page id.
 */
void
IvfflatInitPage(Buffer buf, Page page)
{
	IvfflatPageOpaqueData *opaque;

	PageInit(page, BufferGetPageSize(buf), sizeof(IvfflatPageOpaqueData));

	opaque = IvfflatPageGetOpaque(page);
	opaque->nextblkno = InvalidBlockNumber;
	opaque->page_id = IVFFLAT_PAGE_ID;
}
/*
 * Init and register page
 *
 * Starts a generic xlog record for index and stores it in *state. Registers
 * *buf with a full page image, stores the registered page in *page, and
 * initializes that page.
 */
void
IvfflatInitRegisterPage(Relation index, Buffer *buf, Page *page, GenericXLogState **state)
{
	GenericXLogState *xlog = GenericXLogStart(index);
	Page		registered;

	*state = xlog;
	registered = GenericXLogRegisterBuffer(xlog, *buf, GENERIC_XLOG_FULL_IMAGE);
	*page = registered;

	IvfflatInitPage(*buf, registered);
}
/*
 * Commit buffer
 *
 * Finishes the generic xlog record in state, then unlocks and releases buf.
 * The buffer is released only after the record is finished.
 */
void
IvfflatCommitBuffer(Buffer buf, GenericXLogState *state)
{
	GenericXLogFinish(state);	/* apply and log the registered changes */
	UnlockReleaseBuffer(buf);	/* safe to drop lock/pin only afterwards */
}
/*
* Add a new page
*
* Appends a new page after the current one and makes it current:
* - allocates a new buffer in forkNum (returned exclusively locked by
*   IvfflatNewBuffer),
* - registers it in the current xlog record (*state) with a full image,
* - points the current page's nextblkno at the new block,
* - initializes the new page,
* - finishes *state and releases the old *buf,
* - starts a new xlog record with the new buffer registered.
*
* On return *buf, *page and *state refer to the new page. The new buffer is
* still locked and its record is still open. Presumably the caller finishes
* it later (e.g. via IvfflatCommitBuffer) — TODO confirm at call sites.
*
* NOTE(review): assumes *page is the page obtained by registering *buf with
* *state, so the nextblkno link and the new page's initialization end up in
* the same generic xlog record.
*
* The order is very important!!
* Both pages must be registered before they are modified, and the old
* buffer is released only after GenericXLogFinish.
*/
void
IvfflatAppendPage(Relation index, Buffer *buf, Page *page, GenericXLogState **state, ForkNumber forkNum)
{
/* Get new buffer */
Buffer newbuf = IvfflatNewBuffer(index, forkNum);
/* Register in the current record so its init is logged with the link */
Page newpage = GenericXLogRegisterBuffer(*state, newbuf, GENERIC_XLOG_FULL_IMAGE);
/* Update the previous buffer */
IvfflatPageGetOpaque(*page)->nextblkno = BufferGetBlockNumber(newbuf);
/* Init new page */
IvfflatInitPage(newbuf, newpage);
/* Commit */
GenericXLogFinish(*state);
/* Unlock (old buffer only; newbuf stays locked) */
UnlockReleaseBuffer(*buf);
/* Continue working on the new page under a fresh xlog record */
*state = GenericXLogStart(index);
*page = GenericXLogRegisterBuffer(*state, newbuf, GENERIC_XLOG_FULL_IMAGE);
*buf = newbuf;
}
/*
 * Get the metapage info
 *
 * Reads the metapage under a share lock and raises an ERROR if its magic
 * number does not match IVFFLAT_MAGIC_NUMBER. Otherwise copies the list
 * count and dimensions into whichever output pointers are non-NULL.
 */
void
IvfflatGetMetaPageInfo(Relation index, int *lists, int *dimensions)
{
	Buffer		metaBuf = ReadBuffer(index, IVFFLAT_METAPAGE_BLKNO);
	IvfflatMetaPage meta;

	LockBuffer(metaBuf, BUFFER_LOCK_SHARE);
	meta = IvfflatPageGetMeta(BufferGetPage(metaBuf));

	if (unlikely(meta->magicNumber != IVFFLAT_MAGIC_NUMBER))
		elog(ERROR, "ivfflat index is not valid");

	if (lists)
		*lists = meta->lists;
	if (dimensions)
		*dimensions = meta->dimensions;

	UnlockReleaseBuffer(metaBuf);
}
/*
 * Update the start or insert page of a list
 *
 * Locks the list tuple at listInfo (blkno/offno) and updates:
 * - insertPage, when it is a valid block number that differs from the
 *   stored one and is not below a valid originalInsertPage;
 * - startPage, when it is a valid block number that differs from the
 *   stored one.
 * A WAL record is written only if one of the fields changed. Otherwise the
 * xlog record is aborted and the buffer released.
 */
void
IvfflatUpdateList(Relation index, ListInfo listInfo,
				  BlockNumber insertPage, BlockNumber originalInsertPage,
				  BlockNumber startPage, ForkNumber forkNum)
{
	Buffer		listBuf = ReadBufferExtended(index, forkNum, listInfo.blkno, RBM_NORMAL, NULL);
	GenericXLogState *xlogState;
	Page		listPage;
	IvfflatList entry;
	bool		dirty = false;

	LockBuffer(listBuf, BUFFER_LOCK_EXCLUSIVE);
	xlogState = GenericXLogStart(index);
	listPage = GenericXLogRegisterBuffer(xlogState, listBuf, 0);
	entry = (IvfflatList) PageGetItem(listPage, PageGetItemId(listPage, listInfo.offno));

	/*
	 * Never move the insert page below the original insert page: this keeps
	 * an insert from overwriting an update made by vacuum.
	 */
	if (BlockNumberIsValid(insertPage) && insertPage != entry->insertPage &&
		(!BlockNumberIsValid(originalInsertPage) || insertPage >= originalInsertPage))
	{
		entry->insertPage = insertPage;
		dirty = true;
	}

	if (BlockNumberIsValid(startPage) && startPage != entry->startPage)
	{
		entry->startPage = startPage;
		dirty = true;
	}

	/* Nothing changed: discard the record instead of logging a no-op */
	if (!dirty)
	{
		GenericXLogAbort(xlogState);
		UnlockReleaseBuffer(listBuf);
		return;
	}

	IvfflatCommitBuffer(listBuf, xlogState);
}
/*
 * Normalization functions not defined in this file (presumably defined in
 * another translation unit of the extension — confirm). l2_normalize and
 * halfvec_l2_normalize are referenced by the IvfflatTypeInfo tables in
 * this file.
 */
PGDLLEXPORT Datum l2_normalize(PG_FUNCTION_ARGS);
PGDLLEXPORT Datum halfvec_l2_normalize(PG_FUNCTION_ARGS);
PGDLLEXPORT Datum sparsevec_l2_normalize(PG_FUNCTION_ARGS);
/*
 * Size in bytes of a vector item with the given number of dimensions.
 */
static Size
VectorItemSize(int dimensions)
{
	Size		bytes = VECTOR_SIZE(dimensions);

	return bytes;
}
/*
 * Size in bytes of a halfvec item with the given number of dimensions.
 */
static Size
HalfvecItemSize(int dimensions)
{
	Size		bytes = HALFVEC_SIZE(dimensions);

	return bytes;
}
/*
 * Total size in bytes (header included) of a bit string with the given
 * number of bits.
 */
static Size
BitItemSize(int dimensions)
{
	Size		bytes = VARBITTOTALLEN(dimensions);

	return bytes;
}
/*
 * Write the float center coordinates in x into v as a Vector with the given
 * number of dimensions: sets the varlena size and dim, then copies the
 * first `dimensions` values of x.
 */
static void
VectorUpdateCenter(Pointer v, int dimensions, float *x)
{
	Vector	   *center = (Vector *) v;
	int			i = 0;

	SET_VARSIZE(center, VECTOR_SIZE(dimensions));
	center->dim = dimensions;

	while (i < dimensions)
	{
		center->x[i] = x[i];
		i++;
	}
}
/*
 * Write the float center coordinates in x into v as a HalfVector with the
 * given number of dimensions, converting each value with
 * Float4ToHalfUnchecked. NOTE(review): the "Unchecked" conversion suggests
 * no range check here — assumes centers stay within half range (confirm).
 */
static void
HalfvecUpdateCenter(Pointer v, int dimensions, float *x)
{
	HalfVector *center = (HalfVector *) v;

	SET_VARSIZE(center, HALFVEC_SIZE(dimensions));
	center->dim = dimensions;

	for (int i = 0; i < dimensions; i++)
		center->x[i] = Float4ToHalfUnchecked(x[i]);
}
/*
 * Write the center in x into v as a bit string of `dimensions` bits.
 * Bit k is set when x[k] > 0.5. Bits are packed most-significant-bit first
 * within each byte. All data bytes are cleared first, so padding bits in
 * the last byte stay zero.
 */
static void
BitUpdateCenter(Pointer v, int dimensions, float *x)
{
	VarBit	   *center = (VarBit *) v;
	unsigned char *bits = VARBITS(center);
	uint32		nbytes;

	SET_VARSIZE(center, VARBITTOTALLEN(dimensions));
	VARBITLEN(center) = dimensions;

	/* Byte count depends on the varlena size set just above */
	nbytes = VARBITBYTES(center);
	for (uint32 b = 0; b < nbytes; b++)
		bits[b] = 0;

	for (int k = 0; k < dimensions; k++)
	{
		if (x[k] > 0.5)
			bits[k / 8] |= 1 << (7 - (k % 8));
	}
}
/*
 * Add the elements of the Vector in v to the running per-dimension sums
 * in x.
 *
 * NOTE(review): x is not bounds-checked here — assumes the caller sized it
 * for at least vec->dim elements.
 *
 * vec->dim is read once into a local before the loop. Keep this loop in its
 * current simple form: it was deliberately shaped so the compiler can
 * auto-vectorize it.
 */
static void
VectorSumCenter(Pointer v, float *x)
{
Vector *vec = (Vector *) v;
int dim = vec->dim;
/* Auto-vectorized */
for (int k = 0; k < dim; k++)
x[k] += vec->x[k];
}
/*
 * Add the elements of the HalfVector in v (each converted with
 * HalfToFloat4) to the running per-dimension float sums in x.
 *
 * NOTE(review): x is not bounds-checked here — assumes the caller sized it
 * for at least vec->dim elements.
 *
 * vec->dim is read once into a local before the loop. Keep the loop simple:
 * it is shaped for auto-vectorization, which currently happens on aarch64.
 */
static void
HalfvecSumCenter(Pointer v, float *x)
{
HalfVector *vec = (HalfVector *) v;
int dim = vec->dim;
/* Auto-vectorized on aarch64 */
for (int k = 0; k < dim; k++)
x[k] += HalfToFloat4(vec->x[k]);
}
/*
 * Add the bits of the bit string in v (as 0.0 or 1.0) to the running
 * per-dimension sums in x. Bit k is read most-significant-bit first from
 * byte k / 8.
 *
 * NOTE(review): x is not bounds-checked here — assumes the caller sized it
 * for at least VARBITLEN(vec) elements.
 *
 * The bit length and data pointer are read once before the loop, matching
 * VectorSumCenter/HalfvecSumCenter. The loop stores through float *x, and
 * PostgreSQL builds with -fno-strict-aliasing. As a result the compiler
 * cannot prove the VarBit header is unchanged, and would otherwise reload
 * VARBITLEN(vec) and recompute VARBITS(vec) on every iteration.
 */
static void
BitSumCenter(Pointer v, float *x)
{
	VarBit	   *vec = (VarBit *) v;
	unsigned char *bits = VARBITS(vec);
	int			dim = VARBITLEN(vec);

	for (int k = 0; k < dim; k++)
		x[k] += (float) ((bits[k / 8] >> (7 - (k % 8))) & 0x01);
}
/*
 * Get type info
 *
 * If the index provides an IVFFLAT_TYPE_INFO_PROC support function, calls
 * it and returns the IvfflatTypeInfo pointer it produces. Otherwise returns
 * the built-in info for the vector type (l2_normalize and the Vector
 * size/center callbacks).
 */
const IvfflatTypeInfo *
IvfflatGetTypeInfo(Relation index)
{
	static const IvfflatTypeInfo vectorTypeInfo = {
		.maxDimensions = IVFFLAT_MAX_DIM,
		.normalize = l2_normalize,
		.itemSize = VectorItemSize,
		.updateCenter = VectorUpdateCenter,
		.sumCenter = VectorSumCenter
	};
	FmgrInfo   *procinfo = IvfflatOptionalProcInfo(index, IVFFLAT_TYPE_INFO_PROC);

	if (procinfo != NULL)
		return (const IvfflatTypeInfo *) DatumGetPointer(FunctionCall0Coll(procinfo, InvalidOid));

	/* No type info support function: default to vector */
	return &vectorTypeInfo;
}
/*
 * Type info support function for halfvec. Returns a pointer to a static
 * IvfflatTypeInfo that uses halfvec_l2_normalize and the halfvec
 * size/center callbacks, and allows IVFFLAT_MAX_DIM * 2 dimensions
 * (presumably because half-precision elements take half the space of
 * floats — confirm).
 */
FUNCTION_PREFIX PG_FUNCTION_INFO_V1(ivfflat_halfvec_support);
Datum
ivfflat_halfvec_support(PG_FUNCTION_ARGS)
{
	static const IvfflatTypeInfo halfvecTypeInfo = {
		.sumCenter = HalfvecSumCenter,
		.updateCenter = HalfvecUpdateCenter,
		.itemSize = HalfvecItemSize,
		.normalize = halfvec_l2_normalize,
		.maxDimensions = IVFFLAT_MAX_DIM * 2
	};

	PG_RETURN_POINTER(&halfvecTypeInfo);
}
/*
 * Type info support function for bit. Returns a pointer to a static
 * IvfflatTypeInfo that uses the bit-string size/center callbacks, allows
 * IVFFLAT_MAX_DIM * 32 dimensions, and has no normalize function (NULL).
 */
FUNCTION_PREFIX PG_FUNCTION_INFO_V1(ivfflat_bit_support);
Datum
ivfflat_bit_support(PG_FUNCTION_ARGS)
{
	static const IvfflatTypeInfo bitTypeInfo = {
		.sumCenter = BitSumCenter,
		.updateCenter = BitUpdateCenter,
		.itemSize = BitItemSize,
		.normalize = NULL,		/* bit vectors are never normalized */
		.maxDimensions = IVFFLAT_MAX_DIM * 32
	};

	PG_RETURN_POINTER(&bitTypeInfo);
}