mirror of
https://github.com/pgvector/pgvector.git
synced 2026-07-03 19:20:56 +08:00
Added halfvec type
This commit is contained in:
901
src/halfvec.c
Normal file
901
src/halfvec.c
Normal file
@@ -0,0 +1,901 @@
|
||||
#include "postgres.h"
|
||||
|
||||
#include <math.h>
|
||||
|
||||
#include "catalog/pg_type.h"
|
||||
#include "common/shortest_dec.h"
|
||||
#include "fmgr.h"
|
||||
#include "halfvec.h"
|
||||
#include "lib/stringinfo.h"
|
||||
#include "libpq/pqformat.h"
|
||||
#include "port.h" /* for strtof() */
|
||||
#include "utils/array.h"
|
||||
#include "utils/builtins.h"
|
||||
#include "utils/float.h"
|
||||
#include "utils/lsyscache.h"
|
||||
#include "utils/numeric.h"
|
||||
|
||||
/*
|
||||
* Check if half is NaN
|
||||
*/
|
||||
static inline bool
|
||||
HalfIsNan(half num)
|
||||
{
|
||||
#ifdef FLT16_SUPPORT
|
||||
return isnan(num);
|
||||
#else
|
||||
return (num & 0x7C00) == 0x7C00 && (num & 0x7FFF) != 0x7C00;
|
||||
#endif
|
||||
}
|
||||
|
||||
/*
|
||||
* Check if half is infinite
|
||||
*/
|
||||
static inline bool
|
||||
HalfIsInf(half num)
|
||||
{
|
||||
#ifdef FLT16_SUPPORT
|
||||
return isinf(num);
|
||||
#else
|
||||
return (num & 0x7FFF) == 0x7C00;
|
||||
#endif
|
||||
}
|
||||
|
||||
/*
|
||||
* Check if half is zero
|
||||
*/
|
||||
static inline bool
|
||||
HalfIsZero(half num)
|
||||
{
|
||||
#ifdef FLT16_SUPPORT
|
||||
return num == 0;
|
||||
#else
|
||||
return (num & 0x7FFF) == 0x0000;
|
||||
#endif
|
||||
}
|
||||
|
||||
/*
|
||||
* Get a half from a message buffer
|
||||
*/
|
||||
static half
|
||||
pq_getmsghalf(StringInfo msg)
|
||||
{
|
||||
union
|
||||
{
|
||||
half h;
|
||||
uint16 i;
|
||||
} swap;
|
||||
|
||||
swap.i = pq_getmsgint(msg, 2);
|
||||
return swap.h;
|
||||
}
|
||||
|
||||
/*
|
||||
* Append a half to a StringInfo buffer
|
||||
*/
|
||||
static void
|
||||
pq_sendhalf(StringInfo buf, half h)
|
||||
{
|
||||
union
|
||||
{
|
||||
half h;
|
||||
uint16 i;
|
||||
} swap;
|
||||
|
||||
swap.h = h;
|
||||
pq_sendint16(buf, swap.i);
|
||||
}
|
||||
|
||||
/*
|
||||
* Convert a half to a float4
|
||||
*/
|
||||
static float
|
||||
HalfToFloat4(half num)
|
||||
{
|
||||
#ifdef FLT16_SUPPORT
|
||||
return (float) num;
|
||||
#else
|
||||
/* TODO Improve performance */
|
||||
|
||||
/* Assumes same endianness for floats and integers */
|
||||
/* TODO Use union to swap */
|
||||
uint16 bin = *((uint16 *) &num);
|
||||
uint32 exponent = (bin & 0x7C00) >> 10;
|
||||
uint32 mantissa = bin & 0x03FF;
|
||||
|
||||
/* Sign */
|
||||
uint32 result = (bin & 0x8000) << 16;
|
||||
|
||||
if (exponent == 31)
|
||||
{
|
||||
if (mantissa == 0)
|
||||
{
|
||||
/* Infinite */
|
||||
result |= 0x7F800000;
|
||||
}
|
||||
else
|
||||
{
|
||||
/* NaN */
|
||||
result |= 0x7FC00000;
|
||||
result |= mantissa << 13;
|
||||
}
|
||||
}
|
||||
else if (exponent == 0)
|
||||
{
|
||||
/* Subnormal */
|
||||
if (mantissa != 0)
|
||||
{
|
||||
exponent = -14;
|
||||
|
||||
for (int i = 0; i < 10; i++)
|
||||
{
|
||||
mantissa <<= 1;
|
||||
exponent -= 1;
|
||||
|
||||
if ((mantissa >> 10) % 2 == 1)
|
||||
{
|
||||
mantissa &= 0x03ff;
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
result |= (exponent + 127) << 23;
|
||||
result |= mantissa << 13;
|
||||
}
|
||||
}
|
||||
else
|
||||
{
|
||||
/* Normal */
|
||||
result |= (exponent - 15 + 127) << 23;
|
||||
result |= mantissa << 13;
|
||||
}
|
||||
|
||||
/* TODO Use union to swap */
|
||||
return *((float *) &result);
|
||||
#endif
|
||||
}
|
||||
|
||||
/*
|
||||
* Convert a float4 to a half
|
||||
*/
|
||||
static half
|
||||
Float4ToHalfUnchecked(float num)
|
||||
{
|
||||
#ifdef FLT16_SUPPORT
|
||||
return (_Float16) num;
|
||||
#else
|
||||
/* TODO Improve performance */
|
||||
|
||||
/* Assumes same endianness for floats and integers */
|
||||
/* TODO Use union to swap */
|
||||
uint32 bin = *((uint32 *) &num);
|
||||
int exponent = (bin & 0x7F800000) >> 23;
|
||||
int mantissa = bin & 0x007FFFFF;
|
||||
|
||||
/* Sign */
|
||||
uint16 result = (bin & 0x80000000) >> 16;
|
||||
|
||||
if (isinf(num))
|
||||
{
|
||||
/* Infinite */
|
||||
result |= 0x7C00;
|
||||
}
|
||||
else if (isnan(num))
|
||||
{
|
||||
/* NaN */
|
||||
result |= 0x7E00;
|
||||
result |= mantissa >> 13;
|
||||
}
|
||||
else if (exponent > 98)
|
||||
{
|
||||
int m;
|
||||
int gr;
|
||||
int s;
|
||||
|
||||
exponent -= 127;
|
||||
s = mantissa & 0x00000FFF;
|
||||
|
||||
/* Subnormal */
|
||||
if (exponent < -14)
|
||||
{
|
||||
int diff = -exponent - 14;
|
||||
|
||||
mantissa >>= diff;
|
||||
mantissa += 1 << (23 - diff);
|
||||
s |= mantissa & 0x00000FFF;
|
||||
}
|
||||
|
||||
m = mantissa >> 13;
|
||||
|
||||
/* Round */
|
||||
gr = (mantissa >> 12) % 4;
|
||||
if (gr == 3 || (gr == 1 && s != 0))
|
||||
m += 1;
|
||||
|
||||
if (m == 1024)
|
||||
{
|
||||
m = 0;
|
||||
exponent += 1;
|
||||
}
|
||||
|
||||
if (exponent > 15)
|
||||
{
|
||||
/* Infinite */
|
||||
result |= 0x7C00;
|
||||
}
|
||||
else
|
||||
{
|
||||
if (exponent >= -14)
|
||||
result |= (exponent + 15) << 10;
|
||||
|
||||
result |= m;
|
||||
}
|
||||
}
|
||||
|
||||
/* TODO Use union to swap */
|
||||
return *((half *) & result);
|
||||
#endif
|
||||
}
|
||||
|
||||
/*
|
||||
* Convert a float4 to a half
|
||||
*/
|
||||
static half
|
||||
Float4ToHalf(float num)
|
||||
{
|
||||
half result = Float4ToHalfUnchecked(num);
|
||||
|
||||
if (unlikely(HalfIsInf(result)) && !isinf(num))
|
||||
float_overflow_error();
|
||||
if (unlikely(HalfIsZero(result)) && num != 0.0)
|
||||
float_underflow_error();
|
||||
|
||||
return result;
|
||||
}
|
||||
|
||||
/*
|
||||
* Ensure same dimensions
|
||||
*/
|
||||
static inline void
|
||||
CheckDims(HalfVector * a, HalfVector * b)
|
||||
{
|
||||
if (a->dim != b->dim)
|
||||
ereport(ERROR,
|
||||
(errcode(ERRCODE_DATA_EXCEPTION),
|
||||
errmsg("different halfvec dimensions %d and %d", a->dim, b->dim)));
|
||||
}
|
||||
|
||||
/*
|
||||
* Ensure expected dimensions
|
||||
*/
|
||||
static inline void
|
||||
CheckExpectedDim(int32 typmod, int dim)
|
||||
{
|
||||
if (typmod != -1 && typmod != dim)
|
||||
ereport(ERROR,
|
||||
(errcode(ERRCODE_DATA_EXCEPTION),
|
||||
errmsg("expected %d dimensions, not %d", typmod, dim)));
|
||||
}
|
||||
|
||||
/*
|
||||
* Ensure valid dimensions
|
||||
*/
|
||||
static inline void
|
||||
CheckDim(int dim)
|
||||
{
|
||||
if (dim < 1)
|
||||
ereport(ERROR,
|
||||
(errcode(ERRCODE_DATA_EXCEPTION),
|
||||
errmsg("halfvec must have at least 1 dimension")));
|
||||
|
||||
if (dim > HALFVEC_MAX_DIM)
|
||||
ereport(ERROR,
|
||||
(errcode(ERRCODE_PROGRAM_LIMIT_EXCEEDED),
|
||||
errmsg("halfvec cannot have more than %d dimensions", HALFVEC_MAX_DIM)));
|
||||
}
|
||||
|
||||
/*
|
||||
* Ensure finite element
|
||||
*/
|
||||
static inline void
|
||||
CheckElement(half value)
|
||||
{
|
||||
if (HalfIsNan(value))
|
||||
ereport(ERROR,
|
||||
(errcode(ERRCODE_DATA_EXCEPTION),
|
||||
errmsg("NaN not allowed in halfvec")));
|
||||
|
||||
if (HalfIsInf(value))
|
||||
ereport(ERROR,
|
||||
(errcode(ERRCODE_DATA_EXCEPTION),
|
||||
errmsg("infinite value not allowed in halfvec")));
|
||||
}
|
||||
|
||||
/*
|
||||
* Allocate and initialize a new half vector
|
||||
*/
|
||||
HalfVector *
|
||||
InitHalfVector(int dim)
|
||||
{
|
||||
HalfVector *result;
|
||||
int size;
|
||||
|
||||
size = HALFVEC_SIZE(dim);
|
||||
result = (HalfVector *) palloc0(size);
|
||||
SET_VARSIZE(result, size);
|
||||
result->dim = dim;
|
||||
|
||||
return result;
|
||||
}
|
||||
|
||||
/*
 * Whitespace test, since array_isspace() is static
 *
 * Recognizes space, tab, newline, carriage return, vertical tab and form
 * feed.
 */
static inline bool
halfvec_isspace(char ch)
{
	switch (ch)
	{
		case ' ':
		case '\t':
		case '\n':
		case '\r':
		case '\v':
		case '\f':
			return true;
		default:
			return false;
	}
}
|
||||
|
||||
#if PG_VERSION_NUM < 120003
|
||||
static pg_noinline void
|
||||
float_overflow_error(void)
|
||||
{
|
||||
ereport(ERROR,
|
||||
(errcode(ERRCODE_NUMERIC_VALUE_OUT_OF_RANGE),
|
||||
errmsg("value out of range: overflow")));
|
||||
}
|
||||
|
||||
static pg_noinline void
|
||||
float_underflow_error(void)
|
||||
{
|
||||
ereport(ERROR,
|
||||
(errcode(ERRCODE_NUMERIC_VALUE_OUT_OF_RANGE),
|
||||
errmsg("value out of range: underflow")));
|
||||
}
|
||||
#endif
|
||||
|
||||
/*
|
||||
* Convert textual representation to internal representation
|
||||
*/
|
||||
PGDLLEXPORT PG_FUNCTION_INFO_V1(halfvec_in);
|
||||
Datum
|
||||
halfvec_in(PG_FUNCTION_ARGS)
|
||||
{
|
||||
char *lit = PG_GETARG_CSTRING(0);
|
||||
int32 typmod = PG_GETARG_INT32(2);
|
||||
half x[HALFVEC_MAX_DIM];
|
||||
int dim = 0;
|
||||
char *pt;
|
||||
char *stringEnd;
|
||||
HalfVector *result;
|
||||
char *litcopy = pstrdup(lit);
|
||||
char *str = litcopy;
|
||||
|
||||
while (halfvec_isspace(*str))
|
||||
str++;
|
||||
|
||||
if (*str != '[')
|
||||
ereport(ERROR,
|
||||
(errcode(ERRCODE_INVALID_TEXT_REPRESENTATION),
|
||||
errmsg("malformed halfvec literal: \"%s\"", lit),
|
||||
errdetail("Vector contents must start with \"[\".")));
|
||||
|
||||
str++;
|
||||
pt = strtok(str, ",");
|
||||
stringEnd = pt;
|
||||
|
||||
while (pt != NULL && *stringEnd != ']')
|
||||
{
|
||||
if (dim == HALFVEC_MAX_DIM)
|
||||
ereport(ERROR,
|
||||
(errcode(ERRCODE_PROGRAM_LIMIT_EXCEEDED),
|
||||
errmsg("halfvec cannot have more than %d dimensions", HALFVEC_MAX_DIM)));
|
||||
|
||||
while (halfvec_isspace(*pt))
|
||||
pt++;
|
||||
|
||||
/* Check for empty string like float4in */
|
||||
if (*pt == '\0')
|
||||
ereport(ERROR,
|
||||
(errcode(ERRCODE_INVALID_TEXT_REPRESENTATION),
|
||||
errmsg("invalid input syntax for type halfvec: \"%s\"", lit)));
|
||||
|
||||
/* Use strtof like float4in to avoid a double-rounding problem */
|
||||
x[dim] = Float4ToHalf(strtof(pt, &stringEnd));
|
||||
CheckElement(x[dim]);
|
||||
dim++;
|
||||
|
||||
if (stringEnd == pt)
|
||||
ereport(ERROR,
|
||||
(errcode(ERRCODE_INVALID_TEXT_REPRESENTATION),
|
||||
errmsg("invalid input syntax for type halfvec: \"%s\"", lit)));
|
||||
|
||||
while (halfvec_isspace(*stringEnd))
|
||||
stringEnd++;
|
||||
|
||||
if (*stringEnd != '\0' && *stringEnd != ']')
|
||||
ereport(ERROR,
|
||||
(errcode(ERRCODE_INVALID_TEXT_REPRESENTATION),
|
||||
errmsg("invalid input syntax for type halfvec: \"%s\"", lit)));
|
||||
|
||||
pt = strtok(NULL, ",");
|
||||
}
|
||||
|
||||
if (stringEnd == NULL || *stringEnd != ']')
|
||||
ereport(ERROR,
|
||||
(errcode(ERRCODE_INVALID_TEXT_REPRESENTATION),
|
||||
errmsg("malformed halfvec literal: \"%s\"", lit),
|
||||
errdetail("Unexpected end of input.")));
|
||||
|
||||
stringEnd++;
|
||||
|
||||
/* Only whitespace is allowed after the closing brace */
|
||||
while (halfvec_isspace(*stringEnd))
|
||||
stringEnd++;
|
||||
|
||||
if (*stringEnd != '\0')
|
||||
ereport(ERROR,
|
||||
(errcode(ERRCODE_INVALID_TEXT_REPRESENTATION),
|
||||
errmsg("malformed halfvec literal: \"%s\"", lit),
|
||||
errdetail("Junk after closing right brace.")));
|
||||
|
||||
/* Ensure no consecutive delimiters since strtok skips */
|
||||
for (pt = lit + 1; *pt != '\0'; pt++)
|
||||
{
|
||||
if (pt[-1] == ',' && *pt == ',')
|
||||
ereport(ERROR,
|
||||
(errcode(ERRCODE_INVALID_TEXT_REPRESENTATION),
|
||||
errmsg("malformed halfvec literal: \"%s\"", lit)));
|
||||
}
|
||||
|
||||
if (dim < 1)
|
||||
ereport(ERROR,
|
||||
(errcode(ERRCODE_DATA_EXCEPTION),
|
||||
errmsg("halfvec must have at least 1 dimension")));
|
||||
|
||||
pfree(litcopy);
|
||||
|
||||
CheckExpectedDim(typmod, dim);
|
||||
|
||||
result = InitHalfVector(dim);
|
||||
for (int i = 0; i < dim; i++)
|
||||
result->x[i] = x[i];
|
||||
|
||||
PG_RETURN_POINTER(result);
|
||||
}
|
||||
|
||||
/*
|
||||
* Convert internal representation to textual representation
|
||||
*/
|
||||
PGDLLEXPORT PG_FUNCTION_INFO_V1(halfvec_out);
|
||||
Datum
|
||||
halfvec_out(PG_FUNCTION_ARGS)
|
||||
{
|
||||
HalfVector *vector = PG_GETARG_HALFVEC_P(0);
|
||||
int dim = vector->dim;
|
||||
char *buf;
|
||||
char *ptr;
|
||||
int n;
|
||||
|
||||
/*
|
||||
* Need:
|
||||
*
|
||||
* dim * (FLOAT_SHORTEST_DECIMAL_LEN - 1) bytes for
|
||||
* float_to_shortest_decimal_bufn
|
||||
*
|
||||
* dim - 1 bytes for separator
|
||||
*
|
||||
* 3 bytes for [, ], and \0
|
||||
*/
|
||||
buf = (char *) palloc(FLOAT_SHORTEST_DECIMAL_LEN * dim + 2);
|
||||
ptr = buf;
|
||||
|
||||
*ptr = '[';
|
||||
ptr++;
|
||||
for (int i = 0; i < dim; i++)
|
||||
{
|
||||
if (i > 0)
|
||||
{
|
||||
*ptr = ',';
|
||||
ptr++;
|
||||
}
|
||||
|
||||
n = float_to_shortest_decimal_bufn(HalfToFloat4(vector->x[i]), ptr);
|
||||
ptr += n;
|
||||
}
|
||||
*ptr = ']';
|
||||
ptr++;
|
||||
*ptr = '\0';
|
||||
|
||||
PG_FREE_IF_COPY(vector, 0);
|
||||
PG_RETURN_CSTRING(buf);
|
||||
}
|
||||
|
||||
/*
|
||||
* Convert type modifier
|
||||
*/
|
||||
PGDLLEXPORT PG_FUNCTION_INFO_V1(halfvec_typmod_in);
|
||||
Datum
|
||||
halfvec_typmod_in(PG_FUNCTION_ARGS)
|
||||
{
|
||||
ArrayType *ta = PG_GETARG_ARRAYTYPE_P(0);
|
||||
int32 *tl;
|
||||
int n;
|
||||
|
||||
tl = ArrayGetIntegerTypmods(ta, &n);
|
||||
|
||||
if (n != 1)
|
||||
ereport(ERROR,
|
||||
(errcode(ERRCODE_INVALID_PARAMETER_VALUE),
|
||||
errmsg("invalid type modifier")));
|
||||
|
||||
if (*tl < 1)
|
||||
ereport(ERROR,
|
||||
(errcode(ERRCODE_INVALID_PARAMETER_VALUE),
|
||||
errmsg("dimensions for type halfvec must be at least 1")));
|
||||
|
||||
if (*tl > HALFVEC_MAX_DIM)
|
||||
ereport(ERROR,
|
||||
(errcode(ERRCODE_INVALID_PARAMETER_VALUE),
|
||||
errmsg("dimensions for type halfvec cannot exceed %d", HALFVEC_MAX_DIM)));
|
||||
|
||||
PG_RETURN_INT32(*tl);
|
||||
}
|
||||
|
||||
/*
|
||||
* Convert external binary representation to internal representation
|
||||
*/
|
||||
PGDLLEXPORT PG_FUNCTION_INFO_V1(halfvec_recv);
|
||||
Datum
|
||||
halfvec_recv(PG_FUNCTION_ARGS)
|
||||
{
|
||||
StringInfo buf = (StringInfo) PG_GETARG_POINTER(0);
|
||||
int32 typmod = PG_GETARG_INT32(2);
|
||||
HalfVector *result;
|
||||
int16 dim;
|
||||
int16 unused;
|
||||
|
||||
dim = pq_getmsgint(buf, sizeof(int16));
|
||||
unused = pq_getmsgint(buf, sizeof(int16));
|
||||
|
||||
CheckDim(dim);
|
||||
CheckExpectedDim(typmod, dim);
|
||||
|
||||
if (unused != 0)
|
||||
ereport(ERROR,
|
||||
(errcode(ERRCODE_DATA_EXCEPTION),
|
||||
errmsg("expected unused to be 0, not %d", unused)));
|
||||
|
||||
result = InitHalfVector(dim);
|
||||
for (int i = 0; i < dim; i++)
|
||||
{
|
||||
result->x[i] = pq_getmsghalf(buf);
|
||||
CheckElement(result->x[i]);
|
||||
}
|
||||
|
||||
PG_RETURN_POINTER(result);
|
||||
}
|
||||
|
||||
/*
|
||||
* Convert internal representation to the external binary representation
|
||||
*/
|
||||
PGDLLEXPORT PG_FUNCTION_INFO_V1(halfvec_send);
|
||||
Datum
|
||||
halfvec_send(PG_FUNCTION_ARGS)
|
||||
{
|
||||
HalfVector *vec = PG_GETARG_HALFVEC_P(0);
|
||||
StringInfoData buf;
|
||||
|
||||
pq_begintypsend(&buf);
|
||||
pq_sendint(&buf, vec->dim, sizeof(int16));
|
||||
pq_sendint(&buf, vec->unused, sizeof(int16));
|
||||
for (int i = 0; i < vec->dim; i++)
|
||||
pq_sendhalf(&buf, vec->x[i]);
|
||||
|
||||
PG_RETURN_BYTEA_P(pq_endtypsend(&buf));
|
||||
}
|
||||
|
||||
/*
|
||||
* Convert half vector to half vector
|
||||
* This is needed to check the type modifier
|
||||
*/
|
||||
PGDLLEXPORT PG_FUNCTION_INFO_V1(halfvec);
|
||||
Datum
|
||||
halfvec(PG_FUNCTION_ARGS)
|
||||
{
|
||||
HalfVector *vec = PG_GETARG_HALFVEC_P(0);
|
||||
int32 typmod = PG_GETARG_INT32(1);
|
||||
|
||||
CheckExpectedDim(typmod, vec->dim);
|
||||
|
||||
PG_RETURN_POINTER(vec);
|
||||
}
|
||||
|
||||
/*
|
||||
* Convert array to half vector
|
||||
*/
|
||||
PGDLLEXPORT PG_FUNCTION_INFO_V1(array_to_halfvec);
|
||||
Datum
|
||||
array_to_halfvec(PG_FUNCTION_ARGS)
|
||||
{
|
||||
ArrayType *array = PG_GETARG_ARRAYTYPE_P(0);
|
||||
int32 typmod = PG_GETARG_INT32(1);
|
||||
HalfVector *result;
|
||||
int16 typlen;
|
||||
bool typbyval;
|
||||
char typalign;
|
||||
Datum *elemsp;
|
||||
int nelemsp;
|
||||
|
||||
if (ARR_NDIM(array) > 1)
|
||||
ereport(ERROR,
|
||||
(errcode(ERRCODE_DATA_EXCEPTION),
|
||||
errmsg("array must be 1-D")));
|
||||
|
||||
if (ARR_HASNULL(array) && array_contains_nulls(array))
|
||||
ereport(ERROR,
|
||||
(errcode(ERRCODE_NULL_VALUE_NOT_ALLOWED),
|
||||
errmsg("array must not contain nulls")));
|
||||
|
||||
get_typlenbyvalalign(ARR_ELEMTYPE(array), &typlen, &typbyval, &typalign);
|
||||
deconstruct_array(array, ARR_ELEMTYPE(array), typlen, typbyval, typalign, &elemsp, NULL, &nelemsp);
|
||||
|
||||
CheckDim(nelemsp);
|
||||
CheckExpectedDim(typmod, nelemsp);
|
||||
|
||||
result = InitHalfVector(nelemsp);
|
||||
|
||||
if (ARR_ELEMTYPE(array) == INT4OID)
|
||||
{
|
||||
for (int i = 0; i < nelemsp; i++)
|
||||
result->x[i] = Float4ToHalf(DatumGetInt32(elemsp[i]));
|
||||
}
|
||||
else if (ARR_ELEMTYPE(array) == FLOAT8OID)
|
||||
{
|
||||
for (int i = 0; i < nelemsp; i++)
|
||||
result->x[i] = Float4ToHalf(DatumGetFloat8(elemsp[i]));
|
||||
}
|
||||
else if (ARR_ELEMTYPE(array) == FLOAT4OID)
|
||||
{
|
||||
for (int i = 0; i < nelemsp; i++)
|
||||
result->x[i] = Float4ToHalf(DatumGetFloat4(elemsp[i]));
|
||||
}
|
||||
else if (ARR_ELEMTYPE(array) == NUMERICOID)
|
||||
{
|
||||
for (int i = 0; i < nelemsp; i++)
|
||||
result->x[i] = Float4ToHalf(DatumGetFloat4(DirectFunctionCall1(numeric_float4, elemsp[i])));
|
||||
}
|
||||
else
|
||||
{
|
||||
ereport(ERROR,
|
||||
(errcode(ERRCODE_DATA_EXCEPTION),
|
||||
errmsg("unsupported array type")));
|
||||
}
|
||||
|
||||
/*
|
||||
* Free allocation from deconstruct_array. Do not free individual elements
|
||||
* when pass-by-reference since they point to original array.
|
||||
*/
|
||||
pfree(elemsp);
|
||||
|
||||
/* Check elements */
|
||||
for (int i = 0; i < result->dim; i++)
|
||||
CheckElement(result->x[i]);
|
||||
|
||||
PG_RETURN_POINTER(result);
|
||||
}
|
||||
|
||||
/*
|
||||
* Convert half vector to float4[]
|
||||
*/
|
||||
PGDLLEXPORT PG_FUNCTION_INFO_V1(halfvec_to_float4);
|
||||
Datum
|
||||
halfvec_to_float4(PG_FUNCTION_ARGS)
|
||||
{
|
||||
HalfVector *vec = PG_GETARG_HALFVEC_P(0);
|
||||
Datum *datums;
|
||||
ArrayType *result;
|
||||
|
||||
datums = (Datum *) palloc(sizeof(Datum) * vec->dim);
|
||||
|
||||
for (int i = 0; i < vec->dim; i++)
|
||||
datums[i] = Float4GetDatum(HalfToFloat4(vec->x[i]));
|
||||
|
||||
/* Use TYPALIGN_INT for float4 */
|
||||
result = construct_array(datums, vec->dim, FLOAT4OID, sizeof(float4), true, TYPALIGN_INT);
|
||||
|
||||
pfree(datums);
|
||||
|
||||
PG_RETURN_POINTER(result);
|
||||
}
|
||||
|
||||
/*
|
||||
* Get the L2 distance between half vectors
|
||||
*/
|
||||
PGDLLEXPORT PG_FUNCTION_INFO_V1(halfvec_l2_distance);
|
||||
Datum
|
||||
halfvec_l2_distance(PG_FUNCTION_ARGS)
|
||||
{
|
||||
HalfVector *a = PG_GETARG_HALFVEC_P(0);
|
||||
HalfVector *b = PG_GETARG_HALFVEC_P(1);
|
||||
half *ax = a->x;
|
||||
half *bx = b->x;
|
||||
float distance = 0.0;
|
||||
|
||||
CheckDims(a, b);
|
||||
|
||||
/* Auto-vectorized */
|
||||
for (int i = 0; i < a->dim; i++)
|
||||
{
|
||||
float diff = HalfToFloat4(ax[i]) - HalfToFloat4(bx[i]);
|
||||
|
||||
distance += diff * diff;
|
||||
}
|
||||
|
||||
PG_RETURN_FLOAT8(sqrt((double) distance));
|
||||
}
|
||||
|
||||
/*
|
||||
* Get the L2 squared distance between half vectors
|
||||
*/
|
||||
PGDLLEXPORT PG_FUNCTION_INFO_V1(halfvec_l2_squared_distance);
|
||||
Datum
|
||||
halfvec_l2_squared_distance(PG_FUNCTION_ARGS)
|
||||
{
|
||||
HalfVector *a = PG_GETARG_HALFVEC_P(0);
|
||||
HalfVector *b = PG_GETARG_HALFVEC_P(1);
|
||||
half *ax = a->x;
|
||||
half *bx = b->x;
|
||||
float distance = 0.0;
|
||||
|
||||
CheckDims(a, b);
|
||||
|
||||
/* Auto-vectorized */
|
||||
for (int i = 0; i < a->dim; i++)
|
||||
{
|
||||
float diff = HalfToFloat4(ax[i]) - HalfToFloat4(bx[i]);
|
||||
|
||||
distance += diff * diff;
|
||||
}
|
||||
|
||||
PG_RETURN_FLOAT8((double) distance);
|
||||
}
|
||||
|
||||
/*
|
||||
* Get the inner product of two half vectors
|
||||
*/
|
||||
PGDLLEXPORT PG_FUNCTION_INFO_V1(halfvec_inner_product);
|
||||
Datum
|
||||
halfvec_inner_product(PG_FUNCTION_ARGS)
|
||||
{
|
||||
HalfVector *a = PG_GETARG_HALFVEC_P(0);
|
||||
HalfVector *b = PG_GETARG_HALFVEC_P(1);
|
||||
half *ax = a->x;
|
||||
half *bx = b->x;
|
||||
float distance = 0.0;
|
||||
|
||||
CheckDims(a, b);
|
||||
|
||||
/* Auto-vectorized */
|
||||
for (int i = 0; i < a->dim; i++)
|
||||
distance += HalfToFloat4(ax[i]) * HalfToFloat4(bx[i]);
|
||||
|
||||
PG_RETURN_FLOAT8((double) distance);
|
||||
}
|
||||
|
||||
/*
|
||||
* Get the negative inner product of two half vectors
|
||||
*/
|
||||
PGDLLEXPORT PG_FUNCTION_INFO_V1(halfvec_negative_inner_product);
|
||||
Datum
|
||||
halfvec_negative_inner_product(PG_FUNCTION_ARGS)
|
||||
{
|
||||
HalfVector *a = PG_GETARG_HALFVEC_P(0);
|
||||
HalfVector *b = PG_GETARG_HALFVEC_P(1);
|
||||
half *ax = a->x;
|
||||
half *bx = b->x;
|
||||
float distance = 0.0;
|
||||
|
||||
CheckDims(a, b);
|
||||
|
||||
/* Auto-vectorized */
|
||||
for (int i = 0; i < a->dim; i++)
|
||||
distance += HalfToFloat4(ax[i]) * HalfToFloat4(bx[i]);
|
||||
|
||||
PG_RETURN_FLOAT8((double) distance * -1);
|
||||
}
|
||||
|
||||
/*
|
||||
* Get the cosine distance between two half vectors
|
||||
*/
|
||||
PGDLLEXPORT PG_FUNCTION_INFO_V1(halfvec_cosine_distance);
|
||||
Datum
|
||||
halfvec_cosine_distance(PG_FUNCTION_ARGS)
|
||||
{
|
||||
HalfVector *a = PG_GETARG_HALFVEC_P(0);
|
||||
HalfVector *b = PG_GETARG_HALFVEC_P(1);
|
||||
half *ax = a->x;
|
||||
half *bx = b->x;
|
||||
float distance = 0.0;
|
||||
float norma = 0.0;
|
||||
float normb = 0.0;
|
||||
double similarity;
|
||||
|
||||
CheckDims(a, b);
|
||||
|
||||
/* Auto-vectorized */
|
||||
for (int i = 0; i < a->dim; i++)
|
||||
{
|
||||
float axi = HalfToFloat4(ax[i]);
|
||||
float bxi = HalfToFloat4(bx[i]);
|
||||
|
||||
distance += axi * bxi;
|
||||
norma += axi * axi;
|
||||
normb += bxi * bxi;
|
||||
}
|
||||
|
||||
/* Use sqrt(a * b) over sqrt(a) * sqrt(b) */
|
||||
similarity = (double) distance / sqrt((double) norma * (double) normb);
|
||||
|
||||
#ifdef _MSC_VER
|
||||
/* /fp:fast may not propagate NaN */
|
||||
if (isnan(similarity))
|
||||
PG_RETURN_FLOAT8(NAN);
|
||||
#endif
|
||||
|
||||
/* Keep in range */
|
||||
if (similarity > 1)
|
||||
similarity = 1;
|
||||
else if (similarity < -1)
|
||||
similarity = -1;
|
||||
|
||||
PG_RETURN_FLOAT8(1 - similarity);
|
||||
}
|
||||
|
||||
/*
|
||||
* Get the L1 distance between two half vectors
|
||||
*/
|
||||
PGDLLEXPORT PG_FUNCTION_INFO_V1(halfvec_l1_distance);
|
||||
Datum
|
||||
halfvec_l1_distance(PG_FUNCTION_ARGS)
|
||||
{
|
||||
HalfVector *a = PG_GETARG_HALFVEC_P(0);
|
||||
HalfVector *b = PG_GETARG_HALFVEC_P(1);
|
||||
half *ax = a->x;
|
||||
half *bx = b->x;
|
||||
float distance = 0.0;
|
||||
|
||||
CheckDims(a, b);
|
||||
|
||||
/* Auto-vectorized */
|
||||
for (int i = 0; i < a->dim; i++)
|
||||
distance += fabsf(HalfToFloat4(ax[i]) - HalfToFloat4(bx[i]));
|
||||
|
||||
PG_RETURN_FLOAT8((double) distance);
|
||||
}
|
||||
|
||||
/*
|
||||
* Get the L2 norm of a half vector
|
||||
*/
|
||||
PGDLLEXPORT PG_FUNCTION_INFO_V1(halfvec_norm);
|
||||
Datum
|
||||
halfvec_norm(PG_FUNCTION_ARGS)
|
||||
{
|
||||
HalfVector *a = PG_GETARG_HALFVEC_P(0);
|
||||
half *ax = a->x;
|
||||
double norm = 0.0;
|
||||
|
||||
/* Auto-vectorized */
|
||||
for (int i = 0; i < a->dim; i++)
|
||||
norm += (double) HalfToFloat4(ax[i]) * (double) HalfToFloat4(ax[i]);
|
||||
|
||||
PG_RETURN_FLOAT8(sqrt(norm));
|
||||
}
|
||||
38
src/halfvec.h
Normal file
38
src/halfvec.h
Normal file
@@ -0,0 +1,38 @@
|
||||
#ifndef HALFVEC_H
|
||||
#define HALFVEC_H
|
||||
|
||||
#define __STDC_WANT_IEC_60559_TYPES_EXT__
|
||||
|
||||
#include <float.h>
|
||||
|
||||
#ifdef __FLT16_MAX__
|
||||
#define FLT16_SUPPORT
|
||||
#endif
|
||||
|
||||
#ifdef FLT16_SUPPORT
|
||||
#define half _Float16
|
||||
#define HALF_MAX FLT16_MAX
|
||||
#else
|
||||
/* TODO #pragma message("")? */
|
||||
#define half uint16
|
||||
#define HALF_MAX 65504
|
||||
#endif
|
||||
|
||||
#define HALFVEC_MAX_DIM 32000
|
||||
|
||||
#define HALFVEC_SIZE(_dim) (offsetof(HalfVector, x) + sizeof(half)*(_dim))
|
||||
#define DatumGetHalfVector(x) ((HalfVector *) PG_DETOAST_DATUM(x))
|
||||
#define PG_GETARG_HALFVEC_P(x) DatumGetHalfVector(PG_GETARG_DATUM(x))
|
||||
#define PG_RETURN_HALFVEC_P(x) PG_RETURN_POINTER(x)
|
||||
|
||||
typedef struct HalfVector
|
||||
{
|
||||
int32 vl_len_; /* varlena header (do not touch directly!) */
|
||||
int16 dim; /* number of dimensions */
|
||||
int16 unused;
|
||||
half x[FLEXIBLE_ARRAY_MEMBER];
|
||||
} HalfVector;
|
||||
|
||||
HalfVector *InitHalfVector(int dim);
|
||||
|
||||
#endif
|
||||
@@ -55,6 +55,10 @@
|
||||
#define HNSW_UPDATE_ENTRY_GREATER 1
|
||||
#define HNSW_UPDATE_ENTRY_ALWAYS 2
|
||||
|
||||
/* Data types */
|
||||
#define HNSW_TYPE_VECTOR 1
|
||||
#define HNSW_TYPE_HALFVEC 2
|
||||
|
||||
/* Build phases */
|
||||
/* PROGRESS_CREATEIDX_SUBPHASE_INITIALIZE is 1 */
|
||||
#define PROGRESS_HNSW_PHASE_LOAD 2
|
||||
@@ -242,6 +246,7 @@ typedef struct HnswBuildState
|
||||
Relation index;
|
||||
IndexInfo *indexInfo;
|
||||
ForkNumber forkNum;
|
||||
int type;
|
||||
|
||||
/* Settings */
|
||||
int dimensions;
|
||||
@@ -262,7 +267,6 @@ typedef struct HnswBuildState
|
||||
HnswGraph *graph;
|
||||
double ml;
|
||||
int maxLevel;
|
||||
Vector *normvec;
|
||||
|
||||
/* Memory */
|
||||
MemoryContext graphCtx;
|
||||
@@ -367,7 +371,8 @@ typedef struct HnswVacuumState
|
||||
int HnswGetM(Relation index);
|
||||
int HnswGetEfConstruction(Relation index);
|
||||
FmgrInfo *HnswOptionalProcInfo(Relation index, uint16 procnum);
|
||||
bool HnswNormValue(FmgrInfo *procinfo, Oid collation, Datum *value, Vector * result);
|
||||
int HnswGetType(Relation index);
|
||||
bool HnswNormValue(FmgrInfo *procinfo, Oid collation, Datum *value, int type);
|
||||
Buffer HnswNewBuffer(Relation index, ForkNumber forkNum);
|
||||
void HnswInitPage(Buffer buf, Page page);
|
||||
void HnswInit(void);
|
||||
|
||||
@@ -489,7 +489,7 @@ InsertTuple(Relation index, Datum *values, bool *isnull, ItemPointer heaptid, Hn
|
||||
/* Normalize if needed */
|
||||
if (buildstate->normprocinfo != NULL)
|
||||
{
|
||||
if (!HnswNormValue(buildstate->normprocinfo, buildstate->collation, &value, buildstate->normvec))
|
||||
if (!HnswNormValue(buildstate->normprocinfo, buildstate->collation, &value, buildstate->type))
|
||||
return false;
|
||||
}
|
||||
|
||||
@@ -671,21 +671,27 @@ HnswSharedMemoryAlloc(Size size, void *state)
|
||||
static void
|
||||
InitBuildState(HnswBuildState * buildstate, Relation heap, Relation index, IndexInfo *indexInfo, ForkNumber forkNum)
|
||||
{
|
||||
int maxDimensions = HNSW_MAX_DIM;
|
||||
|
||||
buildstate->heap = heap;
|
||||
buildstate->index = index;
|
||||
buildstate->indexInfo = indexInfo;
|
||||
buildstate->forkNum = forkNum;
|
||||
buildstate->type = HnswGetType(index);
|
||||
|
||||
buildstate->m = HnswGetM(index);
|
||||
buildstate->efConstruction = HnswGetEfConstruction(index);
|
||||
buildstate->dimensions = TupleDescAttr(index->rd_att, 0)->atttypmod;
|
||||
|
||||
if (buildstate->type == HNSW_TYPE_HALFVEC)
|
||||
maxDimensions *= 2;
|
||||
|
||||
/* Require column to have dimensions to be indexed */
|
||||
if (buildstate->dimensions < 0)
|
||||
elog(ERROR, "column does not have dimensions");
|
||||
|
||||
if (buildstate->dimensions > HNSW_MAX_DIM)
|
||||
elog(ERROR, "column cannot have more than %d dimensions for hnsw index", HNSW_MAX_DIM);
|
||||
if (buildstate->dimensions > maxDimensions)
|
||||
elog(ERROR, "column cannot have more than %d dimensions for hnsw index", maxDimensions);
|
||||
|
||||
if (buildstate->efConstruction < 2 * buildstate->m)
|
||||
elog(ERROR, "ef_construction must be greater than or equal to 2 * m");
|
||||
@@ -703,9 +709,6 @@ InitBuildState(HnswBuildState * buildstate, Relation heap, Relation index, Index
|
||||
buildstate->ml = HnswGetMl(buildstate->m);
|
||||
buildstate->maxLevel = HnswGetMaxLevel(buildstate->m);
|
||||
|
||||
/* Reuse for each tuple */
|
||||
buildstate->normvec = InitVector(buildstate->dimensions);
|
||||
|
||||
buildstate->graphCtx = GenerationContextCreate(CurrentMemoryContext,
|
||||
"Hnsw build graph context",
|
||||
#if PG_VERSION_NUM >= 150000
|
||||
@@ -729,7 +732,6 @@ InitBuildState(HnswBuildState * buildstate, Relation heap, Relation index, Index
|
||||
static void
|
||||
FreeBuildState(HnswBuildState * buildstate)
|
||||
{
|
||||
pfree(buildstate->normvec);
|
||||
MemoryContextDelete(buildstate->graphCtx);
|
||||
MemoryContextDelete(buildstate->tmpCtx);
|
||||
}
|
||||
|
||||
@@ -622,7 +622,7 @@ HnswInsertTuple(Relation index, Datum *values, bool *isnull, ItemPointer heap_ti
|
||||
normprocinfo = HnswOptionalProcInfo(index, HNSW_NORM_PROC);
|
||||
if (normprocinfo != NULL)
|
||||
{
|
||||
if (!HnswNormValue(normprocinfo, collation, &value, NULL))
|
||||
if (!HnswNormValue(normprocinfo, collation, &value, HnswGetType(index)))
|
||||
return;
|
||||
}
|
||||
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
#include "postgres.h"
|
||||
|
||||
#include "access/relscan.h"
|
||||
#include "halfvec.h"
|
||||
#include "hnsw.h"
|
||||
#include "pgstat.h"
|
||||
#include "storage/bufmgr.h"
|
||||
@@ -73,7 +74,14 @@ GetScanValue(IndexScanDesc scan)
|
||||
Datum value;
|
||||
|
||||
if (scan->orderByData->sk_flags & SK_ISNULL)
|
||||
value = PointerGetDatum(InitVector(GetDimensions(scan->indexRelation)));
|
||||
{
|
||||
int dimensions = GetDimensions(scan->indexRelation);
|
||||
|
||||
if (HnswGetType(scan->indexRelation) == HNSW_TYPE_HALFVEC)
|
||||
value = PointerGetDatum(InitHalfVector(dimensions));
|
||||
else
|
||||
value = PointerGetDatum(InitVector(dimensions));
|
||||
}
|
||||
else
|
||||
{
|
||||
value = scan->orderByData->sk_argument;
|
||||
@@ -84,7 +92,7 @@ GetScanValue(IndexScanDesc scan)
|
||||
|
||||
/* Fine if normalization fails */
|
||||
if (so->normprocinfo != NULL)
|
||||
HnswNormValue(so->normprocinfo, so->collation, &value, NULL);
|
||||
HnswNormValue(so->normprocinfo, so->collation, &value, HnswGetType(scan->indexRelation));
|
||||
}
|
||||
|
||||
return value;
|
||||
|
||||
@@ -3,12 +3,15 @@
|
||||
#include <math.h>
|
||||
|
||||
#include "access/generic_xlog.h"
|
||||
#include "catalog/pg_type.h"
|
||||
#include "halfvec.h"
|
||||
#include "hnsw.h"
|
||||
#include "lib/pairingheap.h"
|
||||
#include "storage/bufmgr.h"
|
||||
#include "utils/datum.h"
|
||||
#include "utils/memdebug.h"
|
||||
#include "utils/rel.h"
|
||||
#include "utils/syscache.h"
|
||||
#include "vector.h"
|
||||
|
||||
#if PG_VERSION_NUM >= 130000
|
||||
@@ -149,6 +152,32 @@ HnswOptionalProcInfo(Relation index, uint16 procnum)
|
||||
return index_getprocinfo(index, 1, procnum);
|
||||
}
|
||||
|
||||
/*
|
||||
* Get type
|
||||
*/
|
||||
int
|
||||
HnswGetType(Relation index)
|
||||
{
|
||||
Oid typeOid = TupleDescAttr(index->rd_att, 0)->atttypid;
|
||||
HeapTuple tuple;
|
||||
Form_pg_type type;
|
||||
int result;
|
||||
|
||||
tuple = SearchSysCache1(TYPEOID, ObjectIdGetDatum(typeOid));
|
||||
if (!HeapTupleIsValid(tuple))
|
||||
elog(ERROR, "cache lookup failed for type %u", typeOid);
|
||||
|
||||
type = (Form_pg_type) GETSTRUCT(tuple);
|
||||
if (strcmp(NameStr(type->typname), "halfvec") == 0)
|
||||
result = HNSW_TYPE_HALFVEC;
|
||||
else
|
||||
result = HNSW_TYPE_VECTOR;
|
||||
|
||||
ReleaseSysCache(tuple);
|
||||
|
||||
return result;
|
||||
}
|
||||
|
||||
/*
|
||||
* Divide by the norm
|
||||
*
|
||||
@@ -158,21 +187,34 @@ HnswOptionalProcInfo(Relation index, uint16 procnum)
|
||||
* if it's different than the original value
|
||||
*/
|
||||
bool
|
||||
HnswNormValue(FmgrInfo *procinfo, Oid collation, Datum *value, Vector * result)
|
||||
HnswNormValue(FmgrInfo *procinfo, Oid collation, Datum *value, int type)
|
||||
{
|
||||
double norm = DatumGetFloat8(FunctionCall1Coll(procinfo, collation, *value));
|
||||
|
||||
if (norm > 0)
|
||||
{
|
||||
Vector *v = DatumGetVector(*value);
|
||||
if (type == HNSW_TYPE_HALFVEC)
|
||||
{
|
||||
HalfVector *v = DatumGetHalfVector(*value);
|
||||
HalfVector *result = InitHalfVector(v->dim);
|
||||
|
||||
if (result == NULL)
|
||||
result = InitVector(v->dim);
|
||||
for (int i = 0; i < v->dim; i++)
|
||||
result->x[i] = v->x[i] / norm;
|
||||
|
||||
for (int i = 0; i < v->dim; i++)
|
||||
result->x[i] = v->x[i] / norm;
|
||||
*value = PointerGetDatum(result);
|
||||
}
|
||||
else if (type == HNSW_TYPE_VECTOR)
|
||||
{
|
||||
Vector *v = DatumGetVector(*value);
|
||||
Vector *result = InitVector(v->dim);
|
||||
|
||||
*value = PointerGetDatum(result);
|
||||
for (int i = 0; i < v->dim; i++)
|
||||
result->x[i] = v->x[i] / norm;
|
||||
|
||||
*value = PointerGetDatum(result);
|
||||
}
|
||||
else
|
||||
elog(ERROR, "Unsupported type");
|
||||
|
||||
return true;
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user