Added support for halfvec to IVFFlat

This commit is contained in:
Andrew Kane
2024-04-11 19:56:39 -07:00
parent a4531ca51f
commit 8d9400bae3
17 changed files with 439 additions and 39 deletions

View File

@@ -27,32 +27,6 @@
#include <immintrin.h>
#endif
/*
* Check if half is NaN
*/
static inline bool
HalfIsNan(half num)
{
#ifdef FLT16_SUPPORT
return isnan(num);
#else
return (num & 0x7C00) == 0x7C00 && (num & 0x7FFF) != 0x7C00;
#endif
}
/*
* Check if half is infinite
*/
static inline bool
HalfIsInf(half num)
{
#ifdef FLT16_SUPPORT
return isinf(num);
#else
return (num & 0x7FFF) == 0x7C00;
#endif
}
/*
* Get a half from a message buffer
*/

View File

@@ -47,4 +47,24 @@ half Float4ToHalf(float num);
half Float4ToHalfUnchecked(float num);
int halfvec_cmp_internal(HalfVector * a, HalfVector * b);
static inline bool
HalfIsNan(half num)
{
#ifdef FLT16_SUPPORT
return isnan(num);
#else
return (num & 0x7C00) == 0x7C00 && (num & 0x7FFF) != 0x7C00;
#endif
}
static inline bool
HalfIsInf(half num)
{
#ifdef FLT16_SUPPORT
return isinf(num);
#else
return (num & 0x7FFF) == 0x7C00;
#endif
}
#endif

View File

@@ -10,12 +10,14 @@
#include "catalog/pg_operator_d.h"
#include "catalog/pg_type_d.h"
#include "commands/progress.h"
#include "halfvec.h"
#include "ivfflat.h"
#include "miscadmin.h"
#include "optimizer/optimizer.h"
#include "storage/bufmgr.h"
#include "tcop/tcopprot.h"
#include "utils/memutils.h"
#include "vector.h"
#if PG_VERSION_NUM >= 140000
#include "utils/backend_progress.h"
@@ -367,7 +369,7 @@ InitBuildState(IvfflatBuildState * buildstate, Relation heap, Relation index, In
buildstate->slot = MakeSingleTupleTableSlot(buildstate->tupdesc, &TTSOpsVirtual);
buildstate->centers = VectorArrayInit(buildstate->lists, buildstate->dimensions, VECTOR_SIZE(buildstate->dimensions));
buildstate->centers = VectorArrayInit(buildstate->lists, buildstate->dimensions, buildstate->type == IVFFLAT_TYPE_HALFVEC ? HALFVEC_SIZE(buildstate->dimensions) : VECTOR_SIZE(buildstate->dimensions));
buildstate->listInfo = palloc(sizeof(ListInfo) * buildstate->lists);
buildstate->tmpCtx = AllocSetContextCreate(CurrentMemoryContext,

View File

@@ -45,7 +45,8 @@
typedef enum IvfflatType
{
IVFFLAT_TYPE_VECTOR
IVFFLAT_TYPE_VECTOR,
IVFFLAT_TYPE_HALFVEC
} IvfflatType;
/* Build phases */

View File

@@ -3,10 +3,12 @@
#include <float.h>
#include <math.h>
#include "halfvec.h"
#include "ivfflat.h"
#include "miscadmin.h"
#include "utils/datum.h"
#include "utils/memutils.h"
#include "vector.h"
/*
* Initialize with kmeans++
@@ -101,6 +103,13 @@ ApplyNorm(FmgrInfo *normprocinfo, Oid collation, Datum value, IvfflatType type)
for (int i = 0; i < vec->dim; i++)
vec->x[i] /= norm;
}
else if (type == IVFFLAT_TYPE_HALFVEC)
{
HalfVector *vec = DatumGetHalfVector(value);
for (int i = 0; i < vec->dim; i++)
vec->x[i] = Float4ToHalfUnchecked(HalfToFloat4(vec->x[i]) / norm);
}
else
elog(ERROR, "Unsupported type");
}
@@ -115,6 +124,15 @@ CompareVectors(const void *a, const void *b)
return vector_cmp_internal((Vector *) a, (Vector *) b);
}
/*
* Compare half vectors
*/
static int
CompareHalfVectors(const void *a, const void *b)
{
return halfvec_cmp_internal((HalfVector *) a, (HalfVector *) b);
}
/*
* Quick approach if we have little data
*/
@@ -130,6 +148,8 @@ QuickCenters(Relation index, VectorArray samples, VectorArray centers, IvfflatTy
{
if (type == IVFFLAT_TYPE_VECTOR)
qsort(samples->items, samples->length, samples->itemsize, CompareVectors);
else if (type == IVFFLAT_TYPE_HALFVEC)
qsort(samples->items, samples->length, samples->itemsize, CompareHalfVectors);
else
elog(ERROR, "Unsupported type");
@@ -160,6 +180,16 @@ QuickCenters(Relation index, VectorArray samples, VectorArray centers, IvfflatTy
for (int j = 0; j < dimensions; j++)
vec->x[j] = RandomDouble();
}
else if (type == IVFFLAT_TYPE_HALFVEC)
{
HalfVector *vec = DatumGetHalfVector(center);
SET_VARSIZE(vec, HALFVEC_SIZE(dimensions));
vec->dim = dimensions;
for (int j = 0; j < dimensions; j++)
vec->x[j] = Float4ToHalfUnchecked((float) RandomDouble());
}
else
elog(ERROR, "Unsupported type");
@@ -221,6 +251,7 @@ ElkanKmeans(Relation index, VectorArray samples, VectorArray centers, IvfflatTyp
Size samplesSize = VECTOR_ARRAY_SIZE(samples->maxlen, samples->itemsize);
Size centersSize = VECTOR_ARRAY_SIZE(centers->maxlen, centers->itemsize);
Size newCentersSize = VECTOR_ARRAY_SIZE(numCenters, centers->itemsize);
Size aggCentersSize = type == IVFFLAT_TYPE_VECTOR ? 0 : VECTOR_ARRAY_SIZE(numCenters, VECTOR_SIZE(dimensions));
Size centerCountsSize = sizeof(int) * numCenters;
Size closestCentersSize = sizeof(int) * numSamples;
Size lowerBoundSize = sizeof(float) * numSamples * numCenters;
@@ -230,7 +261,7 @@ ElkanKmeans(Relation index, VectorArray samples, VectorArray centers, IvfflatTyp
Size newcdistSize = sizeof(float) * numCenters;
/* Calculate total size */
Size totalSize = samplesSize + centersSize + newCentersSize + centerCountsSize + closestCentersSize + lowerBoundSize + upperBoundSize + sSize + halfcdistSize + newcdistSize;
Size totalSize = samplesSize + centersSize + newCentersSize + aggCentersSize + centerCountsSize + closestCentersSize + lowerBoundSize + upperBoundSize + sSize + halfcdistSize + newcdistSize;
/* Check memory requirements */
/* Add one to error message to ceil */
@@ -265,7 +296,7 @@ ElkanKmeans(Relation index, VectorArray samples, VectorArray centers, IvfflatTyp
halfcdist = palloc_extended(halfcdistSize, MCXT_ALLOC_HUGE);
newcdist = palloc(newcdistSize);
aggCenters = VectorArrayInit(numCenters, dimensions, centers->itemsize);
aggCenters = VectorArrayInit(numCenters, dimensions, VECTOR_SIZE(dimensions));
for (int64 j = 0; j < numCenters; j++)
{
Vector *vec = (Vector *) VectorArrayGet(aggCenters, j);
@@ -276,6 +307,18 @@ ElkanKmeans(Relation index, VectorArray samples, VectorArray centers, IvfflatTyp
if (type == IVFFLAT_TYPE_VECTOR)
newCenters = aggCenters;
else if (type == IVFFLAT_TYPE_HALFVEC)
{
newCenters = VectorArrayInit(numCenters, dimensions, centers->itemsize);
for (int j = 0; j < numCenters; j++)
{
HalfVector *vec = (HalfVector *) VectorArrayGet(newCenters, j);
SET_VARSIZE(vec, HALFVEC_SIZE(dimensions));
vec->dim = dimensions;
}
}
else
elog(ERROR, "Unsupported type");
@@ -430,20 +473,32 @@ ElkanKmeans(Relation index, VectorArray samples, VectorArray centers, IvfflatTyp
for (int64 j = 0; j < numSamples; j++)
{
int closestCenter = closestCenters[j];
Vector *vec = (Vector *) VectorArrayGet(samples, j);
Vector *newCenter = (Vector *) VectorArrayGet(aggCenters, closestCenter);
Vector *aggCenter = (Vector *) VectorArrayGet(aggCenters, closestCenter);
/* Increment sum and count of closest center */
for (int64 k = 0; k < dimensions; k++)
newCenter->x[k] += vec->x[k];
if (type == IVFFLAT_TYPE_VECTOR)
{
Vector *vec = (Vector *) VectorArrayGet(samples, j);
for (int64 k = 0; k < dimensions; k++)
aggCenter->x[k] += vec->x[k];
}
else if (type == IVFFLAT_TYPE_HALFVEC)
{
HalfVector *vec = (HalfVector *) VectorArrayGet(samples, j);
for (int64 k = 0; k < dimensions; k++)
aggCenter->x[k] += HalfToFloat4(vec->x[k]);
}
else
elog(ERROR, "Unsupported type");
centerCounts[closestCenter] += 1;
}
for (int64 j = 0; j < numCenters; j++)
{
Datum center = PointerGetDatum(VectorArrayGet(aggCenters, j));
Vector *vec = DatumGetVector(center);
Vector *vec = (Vector *) VectorArrayGet(aggCenters, j);
if (centerCounts[j] > 0)
{
@@ -464,10 +519,28 @@ ElkanKmeans(Relation index, VectorArray samples, VectorArray centers, IvfflatTyp
for (int64 k = 0; k < dimensions; k++)
vec->x[k] = RandomDouble();
}
}
/* Normalize if needed */
if (normprocinfo != NULL)
ApplyNorm(normprocinfo, collation, center, type);
if (type == IVFFLAT_TYPE_HALFVEC)
{
for (int j = 0; j < numCenters; j++)
{
Vector *aggCenter = (Vector *) VectorArrayGet(aggCenters, j);
HalfVector *newCenter = (HalfVector *) VectorArrayGet(newCenters, j);
for (int k = 0; k < dimensions; k++)
newCenter->x[k] = Float4ToHalfUnchecked(aggCenter->x[k]);
}
}
if (normprocinfo != NULL)
{
for (int j = 0; j < numCenters; j++)
{
Datum newCenter = PointerGetDatum(VectorArrayGet(newCenters, j));
ApplyNorm(normprocinfo, collation, newCenter, type);
}
}
/* Step 5 */
@@ -531,6 +604,19 @@ CheckCenters(Relation index, VectorArray centers, IvfflatType type)
elog(ERROR, "Infinite value detected. Please report a bug.");
}
}
else if (type == IVFFLAT_TYPE_HALFVEC)
{
HalfVector *vec = (HalfVector *) VectorArrayGet(centers, i);
for (int j = 0; j < vec->dim; j++)
{
if (HalfIsNan(vec->x[j]))
elog(ERROR, "NaN detected. Please report a bug.");
if (HalfIsInf(vec->x[j]))
elog(ERROR, "Infinite value detected. Please report a bug.");
}
}
else
elog(ERROR, "Unsupported type");
}
@@ -539,6 +625,8 @@ CheckCenters(Relation index, VectorArray centers, IvfflatType type)
/* Fine to sort in-place */
if (type == IVFFLAT_TYPE_VECTOR)
qsort(centers->items, centers->length, centers->itemsize, CompareVectors);
else if (type == IVFFLAT_TYPE_HALFVEC)
qsort(centers->items, centers->length, centers->itemsize, CompareHalfVectors);
else
elog(ERROR, "Unsupported type");

View File

@@ -5,6 +5,7 @@
#include "access/relscan.h"
#include "catalog/pg_operator_d.h"
#include "catalog/pg_type_d.h"
#include "halfvec.h"
#include "lib/pairingheap.h"
#include "ivfflat.h"
#include "miscadmin.h"
@@ -192,6 +193,8 @@ GetScanValue(IndexScanDesc scan)
if (type == IVFFLAT_TYPE_VECTOR)
value = PointerGetDatum(InitVector(so->dimensions));
else if (type == IVFFLAT_TYPE_HALFVEC)
value = PointerGetDatum(InitHalfVector(so->dimensions));
else
elog(ERROR, "Unsupported type");
}

View File

@@ -2,6 +2,7 @@
#include "access/generic_xlog.h"
#include "catalog/pg_type.h"
#include "halfvec.h"
#include "ivfflat.h"
#include "storage/bufmgr.h"
#include "vector.h"
@@ -77,6 +78,8 @@ IvfflatGetType(Relation index)
type = (Form_pg_type) GETSTRUCT(tuple);
if (strcmp(NameStr(type->typname), "vector") == 0)
result = IVFFLAT_TYPE_VECTOR;
else if (strcmp(NameStr(type->typname), "halfvec") == 0)
result = IVFFLAT_TYPE_HALFVEC;
else
{
ReleaseSysCache(tuple);
@@ -113,6 +116,16 @@ IvfflatNormValue(FmgrInfo *procinfo, Oid collation, Datum *value, IvfflatType ty
*value = PointerGetDatum(result);
}
else if (type == IVFFLAT_TYPE_HALFVEC)
{
HalfVector *v = DatumGetHalfVector(*value);
HalfVector *result = InitHalfVector(v->dim);
for (int i = 0; i < v->dim; i++)
result->x[i] = Float4ToHalfUnchecked(HalfToFloat4(v->x[i]) / norm);
*value = PointerGetDatum(result);
}
else
elog(ERROR, "Unsupported type");