Added more functions for halfvec

This commit is contained in:
Andrew Kane
2024-04-14 13:12:08 -07:00
parent 45cea30943
commit c68c2867fd
7 changed files with 604 additions and 2 deletions

View File

@@ -41,6 +41,19 @@ HalfIsInf(half num)
#endif
}
/*
* Check if half is zero
*/
static inline bool
HalfIsZero(half num)
{
#ifdef FLT16_SUPPORT
return num == 0;
#else
return (num & 0x7FFF) == 0x0000;
#endif
}
/*
* Convert a half to a float4
*/

View File

@@ -146,6 +146,24 @@ halfvec_isspace(char ch)
return false;
}
#if PG_VERSION_NUM < 120003
static pg_noinline void
float_overflow_error(void)
{
ereport(ERROR,
(errcode(ERRCODE_NUMERIC_VALUE_OUT_OF_RANGE),
errmsg("value out of range: overflow")));
}
static pg_noinline void
float_underflow_error(void)
{
ereport(ERROR,
(errcode(ERRCODE_NUMERIC_VALUE_OUT_OF_RANGE),
errmsg("value out of range: underflow")));
}
#endif
/*
* Convert textual representation to internal representation
*/
@@ -677,6 +695,18 @@ halfvec_l1_distance(PG_FUNCTION_ARGS)
PG_RETURN_FLOAT8((double) distance);
}
/*
* Get the dimensions of a half vector
*/
PGDLLEXPORT PG_FUNCTION_INFO_V1(halfvec_dims);
Datum
halfvec_dims(PG_FUNCTION_ARGS)
{
HalfVector *a = PG_GETARG_HALFVEC_P(0);
PG_RETURN_INT32(a->dim);
}
/*
* Get the L2 norm of a half vector
*/
@@ -699,6 +729,126 @@ halfvec_norm(PG_FUNCTION_ARGS)
PG_RETURN_FLOAT8(sqrt(norm));
}
/*
* Add half vectors
*/
PGDLLEXPORT PG_FUNCTION_INFO_V1(halfvec_add);
Datum
halfvec_add(PG_FUNCTION_ARGS)
{
HalfVector *a = PG_GETARG_HALFVEC_P(0);
HalfVector *b = PG_GETARG_HALFVEC_P(1);
half *ax = a->x;
half *bx = b->x;
HalfVector *result;
half *rx;
CheckDims(a, b);
result = InitHalfVector(a->dim);
rx = result->x;
/* Auto-vectorized */
for (int i = 0, imax = a->dim; i < imax; i++)
{
#ifdef FLT16_SUPPORT
rx[i] = ax[i] + bx[i];
#else
rx[i] = Float4ToHalfUnchecked(HalfToFloat4(ax[i]) + HalfToFloat4(bx[i]));
#endif
}
/* Check for overflow */
for (int i = 0, imax = a->dim; i < imax; i++)
{
if (HalfIsInf(rx[i]))
float_overflow_error();
}
PG_RETURN_POINTER(result);
}
/*
* Subtract half vectors
*/
PGDLLEXPORT PG_FUNCTION_INFO_V1(halfvec_sub);
Datum
halfvec_sub(PG_FUNCTION_ARGS)
{
HalfVector *a = PG_GETARG_HALFVEC_P(0);
HalfVector *b = PG_GETARG_HALFVEC_P(1);
half *ax = a->x;
half *bx = b->x;
HalfVector *result;
half *rx;
CheckDims(a, b);
result = InitHalfVector(a->dim);
rx = result->x;
/* Auto-vectorized */
for (int i = 0, imax = a->dim; i < imax; i++)
{
#ifdef FLT16_SUPPORT
rx[i] = ax[i] - bx[i];
#else
rx[i] = Float4ToHalfUnchecked(HalfToFloat4(ax[i]) - HalfToFloat4(bx[i]));
#endif
}
/* Check for overflow */
for (int i = 0, imax = a->dim; i < imax; i++)
{
if (HalfIsInf(rx[i]))
float_overflow_error();
}
PG_RETURN_POINTER(result);
}
/*
* Multiply half vectors
*/
PGDLLEXPORT PG_FUNCTION_INFO_V1(halfvec_mul);
Datum
halfvec_mul(PG_FUNCTION_ARGS)
{
HalfVector *a = PG_GETARG_HALFVEC_P(0);
HalfVector *b = PG_GETARG_HALFVEC_P(1);
half *ax = a->x;
half *bx = b->x;
HalfVector *result;
half *rx;
CheckDims(a, b);
result = InitHalfVector(a->dim);
rx = result->x;
/* Auto-vectorized */
for (int i = 0, imax = a->dim; i < imax; i++)
{
#ifdef FLT16_SUPPORT
rx[i] = ax[i] * bx[i];
#else
rx[i] = Float4ToHalfUnchecked(HalfToFloat4(ax[i]) * HalfToFloat4(bx[i]));
#endif
}
/* Check for overflow and underflow */
for (int i = 0, imax = a->dim; i < imax; i++)
{
if (HalfIsInf(rx[i]))
float_overflow_error();
if (HalfIsZero(rx[i]) && !(HalfIsZero(ax[i]) || HalfIsZero(bx[i])))
float_underflow_error();
}
PG_RETURN_POINTER(result);
}
/*
* Quantize a half vector
*/
@@ -775,3 +925,94 @@ halfvec_cmp_internal(HalfVector * a, HalfVector * b)
return 0;
}
/*
* Less than
*/
PGDLLEXPORT PG_FUNCTION_INFO_V1(halfvec_lt);
Datum
halfvec_lt(PG_FUNCTION_ARGS)
{
HalfVector *a = PG_GETARG_HALFVEC_P(0);
HalfVector *b = PG_GETARG_HALFVEC_P(1);
PG_RETURN_BOOL(halfvec_cmp_internal(a, b) < 0);
}
/*
* Less than or equal
*/
PGDLLEXPORT PG_FUNCTION_INFO_V1(halfvec_le);
Datum
halfvec_le(PG_FUNCTION_ARGS)
{
HalfVector *a = PG_GETARG_HALFVEC_P(0);
HalfVector *b = PG_GETARG_HALFVEC_P(1);
PG_RETURN_BOOL(halfvec_cmp_internal(a, b) <= 0);
}
/*
* Equal
*/
PGDLLEXPORT PG_FUNCTION_INFO_V1(halfvec_eq);
Datum
halfvec_eq(PG_FUNCTION_ARGS)
{
HalfVector *a = PG_GETARG_HALFVEC_P(0);
HalfVector *b = PG_GETARG_HALFVEC_P(1);
PG_RETURN_BOOL(halfvec_cmp_internal(a, b) == 0);
}
/*
* Not equal
*/
PGDLLEXPORT PG_FUNCTION_INFO_V1(halfvec_ne);
Datum
halfvec_ne(PG_FUNCTION_ARGS)
{
HalfVector *a = PG_GETARG_HALFVEC_P(0);
HalfVector *b = PG_GETARG_HALFVEC_P(1);
PG_RETURN_BOOL(halfvec_cmp_internal(a, b) != 0);
}
/*
* Greater than or equal
*/
PGDLLEXPORT PG_FUNCTION_INFO_V1(halfvec_ge);
Datum
halfvec_ge(PG_FUNCTION_ARGS)
{
HalfVector *a = PG_GETARG_HALFVEC_P(0);
HalfVector *b = PG_GETARG_HALFVEC_P(1);
PG_RETURN_BOOL(halfvec_cmp_internal(a, b) >= 0);
}
/*
* Greater than
*/
PGDLLEXPORT PG_FUNCTION_INFO_V1(halfvec_gt);
Datum
halfvec_gt(PG_FUNCTION_ARGS)
{
HalfVector *a = PG_GETARG_HALFVEC_P(0);
HalfVector *b = PG_GETARG_HALFVEC_P(1);
PG_RETURN_BOOL(halfvec_cmp_internal(a, b) > 0);
}
/*
* Compare half vectors
*/
PGDLLEXPORT PG_FUNCTION_INFO_V1(halfvec_cmp);
Datum
halfvec_cmp(PG_FUNCTION_ARGS)
{
HalfVector *a = PG_GETARG_HALFVEC_P(0);
HalfVector *b = PG_GETARG_HALFVEC_P(1);
PG_RETURN_INT32(halfvec_cmp_internal(a, b));
}