mirror of
https://github.com/pgvector/pgvector.git
synced 2026-07-03 11:10:56 +08:00
Added support for halfvec to IVFFlat
This commit is contained in:
@@ -27,32 +27,6 @@
|
||||
#include <immintrin.h>
|
||||
#endif
|
||||
|
||||
/*
|
||||
* Check if half is NaN
|
||||
*/
|
||||
static inline bool
|
||||
HalfIsNan(half num)
|
||||
{
|
||||
#ifdef FLT16_SUPPORT
|
||||
return isnan(num);
|
||||
#else
|
||||
return (num & 0x7C00) == 0x7C00 && (num & 0x7FFF) != 0x7C00;
|
||||
#endif
|
||||
}
|
||||
|
||||
/*
|
||||
* Check if half is infinite
|
||||
*/
|
||||
static inline bool
|
||||
HalfIsInf(half num)
|
||||
{
|
||||
#ifdef FLT16_SUPPORT
|
||||
return isinf(num);
|
||||
#else
|
||||
return (num & 0x7FFF) == 0x7C00;
|
||||
#endif
|
||||
}
|
||||
|
||||
/*
|
||||
* Get a half from a message buffer
|
||||
*/
|
||||
|
||||
@@ -47,4 +47,24 @@ half Float4ToHalf(float num);
|
||||
half Float4ToHalfUnchecked(float num);
|
||||
int halfvec_cmp_internal(HalfVector * a, HalfVector * b);
|
||||
|
||||
static inline bool
|
||||
HalfIsNan(half num)
|
||||
{
|
||||
#ifdef FLT16_SUPPORT
|
||||
return isnan(num);
|
||||
#else
|
||||
return (num & 0x7C00) == 0x7C00 && (num & 0x7FFF) != 0x7C00;
|
||||
#endif
|
||||
}
|
||||
|
||||
static inline bool
|
||||
HalfIsInf(half num)
|
||||
{
|
||||
#ifdef FLT16_SUPPORT
|
||||
return isinf(num);
|
||||
#else
|
||||
return (num & 0x7FFF) == 0x7C00;
|
||||
#endif
|
||||
}
|
||||
|
||||
#endif
|
||||
|
||||
@@ -10,12 +10,14 @@
|
||||
#include "catalog/pg_operator_d.h"
|
||||
#include "catalog/pg_type_d.h"
|
||||
#include "commands/progress.h"
|
||||
#include "halfvec.h"
|
||||
#include "ivfflat.h"
|
||||
#include "miscadmin.h"
|
||||
#include "optimizer/optimizer.h"
|
||||
#include "storage/bufmgr.h"
|
||||
#include "tcop/tcopprot.h"
|
||||
#include "utils/memutils.h"
|
||||
#include "vector.h"
|
||||
|
||||
#if PG_VERSION_NUM >= 140000
|
||||
#include "utils/backend_progress.h"
|
||||
@@ -367,7 +369,7 @@ InitBuildState(IvfflatBuildState * buildstate, Relation heap, Relation index, In
|
||||
|
||||
buildstate->slot = MakeSingleTupleTableSlot(buildstate->tupdesc, &TTSOpsVirtual);
|
||||
|
||||
buildstate->centers = VectorArrayInit(buildstate->lists, buildstate->dimensions, VECTOR_SIZE(buildstate->dimensions));
|
||||
buildstate->centers = VectorArrayInit(buildstate->lists, buildstate->dimensions, buildstate->type == IVFFLAT_TYPE_HALFVEC ? HALFVEC_SIZE(buildstate->dimensions) : VECTOR_SIZE(buildstate->dimensions));
|
||||
buildstate->listInfo = palloc(sizeof(ListInfo) * buildstate->lists);
|
||||
|
||||
buildstate->tmpCtx = AllocSetContextCreate(CurrentMemoryContext,
|
||||
|
||||
@@ -45,7 +45,8 @@
|
||||
|
||||
typedef enum IvfflatType
|
||||
{
|
||||
IVFFLAT_TYPE_VECTOR
|
||||
IVFFLAT_TYPE_VECTOR,
|
||||
IVFFLAT_TYPE_HALFVEC
|
||||
} IvfflatType;
|
||||
|
||||
/* Build phases */
|
||||
|
||||
110
src/ivfkmeans.c
110
src/ivfkmeans.c
@@ -3,10 +3,12 @@
|
||||
#include <float.h>
|
||||
#include <math.h>
|
||||
|
||||
#include "halfvec.h"
|
||||
#include "ivfflat.h"
|
||||
#include "miscadmin.h"
|
||||
#include "utils/datum.h"
|
||||
#include "utils/memutils.h"
|
||||
#include "vector.h"
|
||||
|
||||
/*
|
||||
* Initialize with kmeans++
|
||||
@@ -101,6 +103,13 @@ ApplyNorm(FmgrInfo *normprocinfo, Oid collation, Datum value, IvfflatType type)
|
||||
for (int i = 0; i < vec->dim; i++)
|
||||
vec->x[i] /= norm;
|
||||
}
|
||||
else if (type == IVFFLAT_TYPE_HALFVEC)
|
||||
{
|
||||
HalfVector *vec = DatumGetHalfVector(value);
|
||||
|
||||
for (int i = 0; i < vec->dim; i++)
|
||||
vec->x[i] = Float4ToHalfUnchecked(HalfToFloat4(vec->x[i]) / norm);
|
||||
}
|
||||
else
|
||||
elog(ERROR, "Unsupported type");
|
||||
}
|
||||
@@ -115,6 +124,15 @@ CompareVectors(const void *a, const void *b)
|
||||
return vector_cmp_internal((Vector *) a, (Vector *) b);
|
||||
}
|
||||
|
||||
/*
|
||||
* Compare half vectors
|
||||
*/
|
||||
static int
|
||||
CompareHalfVectors(const void *a, const void *b)
|
||||
{
|
||||
return halfvec_cmp_internal((HalfVector *) a, (HalfVector *) b);
|
||||
}
|
||||
|
||||
/*
|
||||
* Quick approach if we have little data
|
||||
*/
|
||||
@@ -130,6 +148,8 @@ QuickCenters(Relation index, VectorArray samples, VectorArray centers, IvfflatTy
|
||||
{
|
||||
if (type == IVFFLAT_TYPE_VECTOR)
|
||||
qsort(samples->items, samples->length, samples->itemsize, CompareVectors);
|
||||
else if (type == IVFFLAT_TYPE_HALFVEC)
|
||||
qsort(samples->items, samples->length, samples->itemsize, CompareHalfVectors);
|
||||
else
|
||||
elog(ERROR, "Unsupported type");
|
||||
|
||||
@@ -160,6 +180,16 @@ QuickCenters(Relation index, VectorArray samples, VectorArray centers, IvfflatTy
|
||||
for (int j = 0; j < dimensions; j++)
|
||||
vec->x[j] = RandomDouble();
|
||||
}
|
||||
else if (type == IVFFLAT_TYPE_HALFVEC)
|
||||
{
|
||||
HalfVector *vec = DatumGetHalfVector(center);
|
||||
|
||||
SET_VARSIZE(vec, HALFVEC_SIZE(dimensions));
|
||||
vec->dim = dimensions;
|
||||
|
||||
for (int j = 0; j < dimensions; j++)
|
||||
vec->x[j] = Float4ToHalfUnchecked((float) RandomDouble());
|
||||
}
|
||||
else
|
||||
elog(ERROR, "Unsupported type");
|
||||
|
||||
@@ -221,6 +251,7 @@ ElkanKmeans(Relation index, VectorArray samples, VectorArray centers, IvfflatTyp
|
||||
Size samplesSize = VECTOR_ARRAY_SIZE(samples->maxlen, samples->itemsize);
|
||||
Size centersSize = VECTOR_ARRAY_SIZE(centers->maxlen, centers->itemsize);
|
||||
Size newCentersSize = VECTOR_ARRAY_SIZE(numCenters, centers->itemsize);
|
||||
Size aggCentersSize = type == IVFFLAT_TYPE_VECTOR ? 0 : VECTOR_ARRAY_SIZE(numCenters, VECTOR_SIZE(dimensions));
|
||||
Size centerCountsSize = sizeof(int) * numCenters;
|
||||
Size closestCentersSize = sizeof(int) * numSamples;
|
||||
Size lowerBoundSize = sizeof(float) * numSamples * numCenters;
|
||||
@@ -230,7 +261,7 @@ ElkanKmeans(Relation index, VectorArray samples, VectorArray centers, IvfflatTyp
|
||||
Size newcdistSize = sizeof(float) * numCenters;
|
||||
|
||||
/* Calculate total size */
|
||||
Size totalSize = samplesSize + centersSize + newCentersSize + centerCountsSize + closestCentersSize + lowerBoundSize + upperBoundSize + sSize + halfcdistSize + newcdistSize;
|
||||
Size totalSize = samplesSize + centersSize + newCentersSize + aggCentersSize + centerCountsSize + closestCentersSize + lowerBoundSize + upperBoundSize + sSize + halfcdistSize + newcdistSize;
|
||||
|
||||
/* Check memory requirements */
|
||||
/* Add one to error message to ceil */
|
||||
@@ -265,7 +296,7 @@ ElkanKmeans(Relation index, VectorArray samples, VectorArray centers, IvfflatTyp
|
||||
halfcdist = palloc_extended(halfcdistSize, MCXT_ALLOC_HUGE);
|
||||
newcdist = palloc(newcdistSize);
|
||||
|
||||
aggCenters = VectorArrayInit(numCenters, dimensions, centers->itemsize);
|
||||
aggCenters = VectorArrayInit(numCenters, dimensions, VECTOR_SIZE(dimensions));
|
||||
for (int64 j = 0; j < numCenters; j++)
|
||||
{
|
||||
Vector *vec = (Vector *) VectorArrayGet(aggCenters, j);
|
||||
@@ -276,6 +307,18 @@ ElkanKmeans(Relation index, VectorArray samples, VectorArray centers, IvfflatTyp
|
||||
|
||||
if (type == IVFFLAT_TYPE_VECTOR)
|
||||
newCenters = aggCenters;
|
||||
else if (type == IVFFLAT_TYPE_HALFVEC)
|
||||
{
|
||||
newCenters = VectorArrayInit(numCenters, dimensions, centers->itemsize);
|
||||
|
||||
for (int j = 0; j < numCenters; j++)
|
||||
{
|
||||
HalfVector *vec = (HalfVector *) VectorArrayGet(newCenters, j);
|
||||
|
||||
SET_VARSIZE(vec, HALFVEC_SIZE(dimensions));
|
||||
vec->dim = dimensions;
|
||||
}
|
||||
}
|
||||
else
|
||||
elog(ERROR, "Unsupported type");
|
||||
|
||||
@@ -430,20 +473,32 @@ ElkanKmeans(Relation index, VectorArray samples, VectorArray centers, IvfflatTyp
|
||||
for (int64 j = 0; j < numSamples; j++)
|
||||
{
|
||||
int closestCenter = closestCenters[j];
|
||||
Vector *vec = (Vector *) VectorArrayGet(samples, j);
|
||||
Vector *newCenter = (Vector *) VectorArrayGet(aggCenters, closestCenter);
|
||||
Vector *aggCenter = (Vector *) VectorArrayGet(aggCenters, closestCenter);
|
||||
|
||||
/* Increment sum and count of closest center */
|
||||
for (int64 k = 0; k < dimensions; k++)
|
||||
newCenter->x[k] += vec->x[k];
|
||||
if (type == IVFFLAT_TYPE_VECTOR)
|
||||
{
|
||||
Vector *vec = (Vector *) VectorArrayGet(samples, j);
|
||||
|
||||
for (int64 k = 0; k < dimensions; k++)
|
||||
aggCenter->x[k] += vec->x[k];
|
||||
}
|
||||
else if (type == IVFFLAT_TYPE_HALFVEC)
|
||||
{
|
||||
HalfVector *vec = (HalfVector *) VectorArrayGet(samples, j);
|
||||
|
||||
for (int64 k = 0; k < dimensions; k++)
|
||||
aggCenter->x[k] += HalfToFloat4(vec->x[k]);
|
||||
}
|
||||
else
|
||||
elog(ERROR, "Unsupported type");
|
||||
|
||||
centerCounts[closestCenter] += 1;
|
||||
}
|
||||
|
||||
for (int64 j = 0; j < numCenters; j++)
|
||||
{
|
||||
Datum center = PointerGetDatum(VectorArrayGet(aggCenters, j));
|
||||
Vector *vec = DatumGetVector(center);
|
||||
Vector *vec = (Vector *) VectorArrayGet(aggCenters, j);
|
||||
|
||||
if (centerCounts[j] > 0)
|
||||
{
|
||||
@@ -464,10 +519,28 @@ ElkanKmeans(Relation index, VectorArray samples, VectorArray centers, IvfflatTyp
|
||||
for (int64 k = 0; k < dimensions; k++)
|
||||
vec->x[k] = RandomDouble();
|
||||
}
|
||||
}
|
||||
|
||||
/* Normalize if needed */
|
||||
if (normprocinfo != NULL)
|
||||
ApplyNorm(normprocinfo, collation, center, type);
|
||||
if (type == IVFFLAT_TYPE_HALFVEC)
|
||||
{
|
||||
for (int j = 0; j < numCenters; j++)
|
||||
{
|
||||
Vector *aggCenter = (Vector *) VectorArrayGet(aggCenters, j);
|
||||
HalfVector *newCenter = (HalfVector *) VectorArrayGet(newCenters, j);
|
||||
|
||||
for (int k = 0; k < dimensions; k++)
|
||||
newCenter->x[k] = Float4ToHalfUnchecked(aggCenter->x[k]);
|
||||
}
|
||||
}
|
||||
|
||||
if (normprocinfo != NULL)
|
||||
{
|
||||
for (int j = 0; j < numCenters; j++)
|
||||
{
|
||||
Datum newCenter = PointerGetDatum(VectorArrayGet(newCenters, j));
|
||||
|
||||
ApplyNorm(normprocinfo, collation, newCenter, type);
|
||||
}
|
||||
}
|
||||
|
||||
/* Step 5 */
|
||||
@@ -531,6 +604,19 @@ CheckCenters(Relation index, VectorArray centers, IvfflatType type)
|
||||
elog(ERROR, "Infinite value detected. Please report a bug.");
|
||||
}
|
||||
}
|
||||
else if (type == IVFFLAT_TYPE_HALFVEC)
|
||||
{
|
||||
HalfVector *vec = (HalfVector *) VectorArrayGet(centers, i);
|
||||
|
||||
for (int j = 0; j < vec->dim; j++)
|
||||
{
|
||||
if (HalfIsNan(vec->x[j]))
|
||||
elog(ERROR, "NaN detected. Please report a bug.");
|
||||
|
||||
if (HalfIsInf(vec->x[j]))
|
||||
elog(ERROR, "Infinite value detected. Please report a bug.");
|
||||
}
|
||||
}
|
||||
else
|
||||
elog(ERROR, "Unsupported type");
|
||||
}
|
||||
@@ -539,6 +625,8 @@ CheckCenters(Relation index, VectorArray centers, IvfflatType type)
|
||||
/* Fine to sort in-place */
|
||||
if (type == IVFFLAT_TYPE_VECTOR)
|
||||
qsort(centers->items, centers->length, centers->itemsize, CompareVectors);
|
||||
else if (type == IVFFLAT_TYPE_HALFVEC)
|
||||
qsort(centers->items, centers->length, centers->itemsize, CompareHalfVectors);
|
||||
else
|
||||
elog(ERROR, "Unsupported type");
|
||||
|
||||
|
||||
@@ -5,6 +5,7 @@
|
||||
#include "access/relscan.h"
|
||||
#include "catalog/pg_operator_d.h"
|
||||
#include "catalog/pg_type_d.h"
|
||||
#include "halfvec.h"
|
||||
#include "lib/pairingheap.h"
|
||||
#include "ivfflat.h"
|
||||
#include "miscadmin.h"
|
||||
@@ -192,6 +193,8 @@ GetScanValue(IndexScanDesc scan)
|
||||
|
||||
if (type == IVFFLAT_TYPE_VECTOR)
|
||||
value = PointerGetDatum(InitVector(so->dimensions));
|
||||
else if (type == IVFFLAT_TYPE_HALFVEC)
|
||||
value = PointerGetDatum(InitHalfVector(so->dimensions));
|
||||
else
|
||||
elog(ERROR, "Unsupported type");
|
||||
}
|
||||
|
||||
@@ -2,6 +2,7 @@
|
||||
|
||||
#include "access/generic_xlog.h"
|
||||
#include "catalog/pg_type.h"
|
||||
#include "halfvec.h"
|
||||
#include "ivfflat.h"
|
||||
#include "storage/bufmgr.h"
|
||||
#include "vector.h"
|
||||
@@ -77,6 +78,8 @@ IvfflatGetType(Relation index)
|
||||
type = (Form_pg_type) GETSTRUCT(tuple);
|
||||
if (strcmp(NameStr(type->typname), "vector") == 0)
|
||||
result = IVFFLAT_TYPE_VECTOR;
|
||||
else if (strcmp(NameStr(type->typname), "halfvec") == 0)
|
||||
result = IVFFLAT_TYPE_HALFVEC;
|
||||
else
|
||||
{
|
||||
ReleaseSysCache(tuple);
|
||||
@@ -113,6 +116,16 @@ IvfflatNormValue(FmgrInfo *procinfo, Oid collation, Datum *value, IvfflatType ty
|
||||
|
||||
*value = PointerGetDatum(result);
|
||||
}
|
||||
else if (type == IVFFLAT_TYPE_HALFVEC)
|
||||
{
|
||||
HalfVector *v = DatumGetHalfVector(*value);
|
||||
HalfVector *result = InitHalfVector(v->dim);
|
||||
|
||||
for (int i = 0; i < v->dim; i++)
|
||||
result->x[i] = Float4ToHalfUnchecked(HalfToFloat4(v->x[i]) / norm);
|
||||
|
||||
*value = PointerGetDatum(result);
|
||||
}
|
||||
else
|
||||
elog(ERROR, "Unsupported type");
|
||||
|
||||
|
||||
Reference in New Issue
Block a user